diff --git a/.codespell-whitelist.txt b/.codespell-whitelist.txt index dcfed576bf..6b1a432b87 100644 --- a/.codespell-whitelist.txt +++ b/.codespell-whitelist.txt @@ -1 +1 @@ -dependant +dependant \ No newline at end of file diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 93f6f4cabc..e5333e3e48 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -22,5 +22,5 @@ Please delete options that are not relevant. - [ ] My changes do not generate any new warnings - [ ] I have added tests to validate the effectiveness of my fix or the functionality of my new feature - [ ] Both new and existing unit tests pass successfully on my local environment by running `scripts/test-cov.sh` -- [ ] I have ensured that static analysis tests are passing by running `scripts/static-anaylysis.sh` +- [ ] I have ensured that static analysis tests are passing by running `scripts/static-analysis.sh` - [ ] I have included code examples to illustrate the modifications diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 6be65fa584..ddf783ded9 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -84,7 +84,7 @@ jobs: key: ${{ runner.os }}-python-${{ env.pythonLocation }}-${{ hashFiles('pyproject.toml') }}-test-v03 - name: Install Dependencies if: steps.cache.outputs.cache-hit != 'true' - run: pip install .[rabbit,kafka,confluent,nats,redis,testing] + run: pip install .[optionals,testing] - name: Install Pydantic v1 if: matrix.pydantic-version == 'pydantic-v1' run: pip install "pydantic>=1.10.0,<2.0.0" @@ -117,7 +117,7 @@ jobs: cache-dependency-path: pyproject.toml - name: Install Dependencies if: steps.cache.outputs.cache-hit != 'true' - run: pip install .[nats,kafka,confluent,rabbit,redis,testing] orjson + run: pip install .[optionals,testing] orjson - run: mkdir coverage - name: Test run: bash scripts/test.sh -m"(slow and (not nats and not kafka and not confluent and not rabbit and not redis)) or (not nats and not kafka and not confluent and not rabbit and not redis)" @@ -144,7 +144,7 @@ jobs: cache-dependency-path: pyproject.toml - name: Install Dependencies if: steps.cache.outputs.cache-hit != 'true' - run: pip install .[rabbit,kafka,confluent,nats,redis,testing] + run: pip install .[optionals,testing] - name: Test run: bash scripts/test.sh -m "(slow and (not nats and not kafka and not confluent and not rabbit and not redis)) or (not nats and not kafka and not confluent and not rabbit and not redis)" @@ -161,7 +161,7 @@ jobs: cache-dependency-path: pyproject.toml - name: Install Dependencies if: steps.cache.outputs.cache-hit != 'true' - run: pip install .[rabbit,kafka,confluent,nats,redis,testing] + run: pip install .[optionals,testing] - name: Test run: bash scripts/test.sh -m "(slow and (not nats and not kafka and not confluent and not rabbit and not redis)) or (not nats and not kafka and not confluent and not rabbit and not redis)" @@ -194,7 +194,7 @@ jobs: cache-dependency-path: pyproject.toml - name: Install Dependencies if: steps.cache.outputs.cache-hit != 'true' - run: pip install .[nats,kafka,confluent,rabbit,redis,testing] + run: pip install .[optionals,testing] - run: mkdir coverage - name: Test run: bash scripts/test.sh -m "(slow and kafka) or kafka" @@ -254,7 +254,7 @@ jobs: cache-dependency-path: pyproject.toml - name: Install Dependencies if: steps.cache.outputs.cache-hit != 'true' - run: pip install .[nats,kafka,confluent,rabbit,redis,testing] + run: pip install .[optionals,testing] - run: mkdir coverage - name: Test run: bash scripts/test.sh -m "(slow and confluent) or confluent" @@ -303,7 +303,7 @@ jobs: cache-dependency-path: pyproject.toml - name: Install Dependencies if: steps.cache.outputs.cache-hit != 'true' - run: pip install .[nats,kafka,confluent,rabbit,redis,testing] + run: pip install .[optionals,testing] - run: mkdir coverage - name: Test run: bash scripts/test.sh -m "(slow and rabbit) or rabbit" @@ -352,7 +352,7 @@ jobs: cache-dependency-path: pyproject.toml - name: Install Dependencies if: steps.cache.outputs.cache-hit != 'true' - run: pip install .[nats,kafka,confluent,rabbit,redis,testing] + run: pip install .[optionals,testing] - run: mkdir coverage - name: Test run: bash scripts/test.sh -m "(slow and nats) or nats" @@ -401,7 +401,7 @@ jobs: cache-dependency-path: pyproject.toml - name: Install Dependencies if: steps.cache.outputs.cache-hit != 'true' - run: pip install .[nats,kafka,confluent,rabbit,redis,testing] + run: pip install .[optionals,testing] - run: mkdir coverage - name: Test run: bash scripts/test.sh -m "(slow and redis) or redis" diff --git a/.secrets.baseline b/.secrets.baseline index 5d509637bd..4c3829ee62 100644 --- a/.secrets.baseline +++ b/.secrets.baseline @@ -128,7 +128,7 @@ "filename": "docs/docs/en/release.md", "hashed_secret": "35675e68f4b5af7b995d9205ad0fc43842f16450", "is_verified": false, - "line_number": 836, + "line_number": 1079, "is_secret": false } ], @@ -163,5 +163,5 @@ } ] }, - "generated_at": "2024-04-07T03:11:32Z" + "generated_at": "2024-04-23T11:41:19Z" } diff --git a/docs/docs/SUMMARY.md b/docs/docs/SUMMARY.md index 46fe5f4026..61f070b7bf 100644 --- a/docs/docs/SUMMARY.md +++ b/docs/docs/SUMMARY.md @@ -41,6 +41,7 @@ search: - [FastAPI Plugin](getting-started/integrations/fastapi/index.md) - [Django](getting-started/integrations/django/index.md) - [CLI commands](getting-started/cli/index.md) + - [OpenTelemetry](getting-started/opentelemetry/index.md) - [Logging](getting-started/logging.md) - [Config Management](getting-started/config/index.md) - [Task Scheduling](scheduling.md) @@ -134,6 +135,7 @@ search: - [KafkaRouter](public_api/faststream/kafka/KafkaRouter.md) - [TestApp](public_api/faststream/kafka/TestApp.md) - [TestKafkaBroker](public_api/faststream/kafka/TestKafkaBroker.md) + - [TopicPartition](public_api/faststream/kafka/TopicPartition.md) - nats - [AckPolicy](public_api/faststream/nats/AckPolicy.md) - [ConsumerConfig](public_api/faststream/nats/ConsumerConfig.md) @@ -438,6 +440,15 @@ search: - [ConsumerProtocol](api/faststream/confluent/message/ConsumerProtocol.md) - [FakeConsumer](api/faststream/confluent/message/FakeConsumer.md) - [KafkaMessage](api/faststream/confluent/message/KafkaMessage.md) + - opentelemetry + - [KafkaTelemetryMiddleware](api/faststream/confluent/opentelemetry/KafkaTelemetryMiddleware.md) + - middleware + - [KafkaTelemetryMiddleware](api/faststream/confluent/opentelemetry/middleware/KafkaTelemetryMiddleware.md) + - provider + - [BaseConfluentTelemetrySettingsProvider](api/faststream/confluent/opentelemetry/provider/BaseConfluentTelemetrySettingsProvider.md) + - [BatchConfluentTelemetrySettingsProvider](api/faststream/confluent/opentelemetry/provider/BatchConfluentTelemetrySettingsProvider.md) + - [ConfluentTelemetrySettingsProvider](api/faststream/confluent/opentelemetry/provider/ConfluentTelemetrySettingsProvider.md) + - [telemetry_attributes_provider_factory](api/faststream/confluent/opentelemetry/provider/telemetry_attributes_provider_factory.md) - parser - [AsyncConfluentParser](api/faststream/confluent/parser/AsyncConfluentParser.md) - publisher @@ -495,6 +506,7 @@ search: - [KafkaRouter](api/faststream/kafka/KafkaRouter.md) - [TestApp](api/faststream/kafka/TestApp.md) - [TestKafkaBroker](api/faststream/kafka/TestKafkaBroker.md) + - [TopicPartition](api/faststream/kafka/TopicPartition.md) - broker - [KafkaBroker](api/faststream/kafka/broker/KafkaBroker.md) - broker @@ -512,6 +524,15 @@ search: - [ConsumerProtocol](api/faststream/kafka/message/ConsumerProtocol.md) - [FakeConsumer](api/faststream/kafka/message/FakeConsumer.md) - [KafkaMessage](api/faststream/kafka/message/KafkaMessage.md) + - opentelemetry + - [KafkaTelemetryMiddleware](api/faststream/kafka/opentelemetry/KafkaTelemetryMiddleware.md) + - middleware + - [KafkaTelemetryMiddleware](api/faststream/kafka/opentelemetry/middleware/KafkaTelemetryMiddleware.md) + - provider + - [BaseKafkaTelemetrySettingsProvider](api/faststream/kafka/opentelemetry/provider/BaseKafkaTelemetrySettingsProvider.md) + - [BatchKafkaTelemetrySettingsProvider](api/faststream/kafka/opentelemetry/provider/BatchKafkaTelemetrySettingsProvider.md) + - [KafkaTelemetrySettingsProvider](api/faststream/kafka/opentelemetry/provider/KafkaTelemetrySettingsProvider.md) + - [telemetry_attributes_provider_factory](api/faststream/kafka/opentelemetry/provider/telemetry_attributes_provider_factory.md) - parser - [AioKafkaParser](api/faststream/kafka/parser/AioKafkaParser.md) - publisher @@ -594,6 +615,15 @@ search: - message - [NatsBatchMessage](api/faststream/nats/message/NatsBatchMessage.md) - [NatsMessage](api/faststream/nats/message/NatsMessage.md) + - opentelemetry + - [NatsTelemetryMiddleware](api/faststream/nats/opentelemetry/NatsTelemetryMiddleware.md) + - middleware + - [NatsTelemetryMiddleware](api/faststream/nats/opentelemetry/middleware/NatsTelemetryMiddleware.md) + - provider + - [BaseNatsTelemetrySettingsProvider](api/faststream/nats/opentelemetry/provider/BaseNatsTelemetrySettingsProvider.md) + - [NatsBatchTelemetrySettingsProvider](api/faststream/nats/opentelemetry/provider/NatsBatchTelemetrySettingsProvider.md) + - [NatsTelemetrySettingsProvider](api/faststream/nats/opentelemetry/provider/NatsTelemetrySettingsProvider.md) + - [telemetry_attributes_provider_factory](api/faststream/nats/opentelemetry/provider/telemetry_attributes_provider_factory.md) - parser - [BatchParser](api/faststream/nats/parser/BatchParser.md) - [JsParser](api/faststream/nats/parser/JsParser.md) @@ -636,6 +666,16 @@ search: - [PatchedMessage](api/faststream/nats/testing/PatchedMessage.md) - [TestNatsBroker](api/faststream/nats/testing/TestNatsBroker.md) - [build_message](api/faststream/nats/testing/build_message.md) + - opentelemetry + - [TelemetryMiddleware](api/faststream/opentelemetry/TelemetryMiddleware.md) + - [TelemetrySettingsProvider](api/faststream/opentelemetry/TelemetrySettingsProvider.md) + - consts + - [MessageAction](api/faststream/opentelemetry/consts/MessageAction.md) + - middleware + - [BaseTelemetryMiddleware](api/faststream/opentelemetry/middleware/BaseTelemetryMiddleware.md) + - [TelemetryMiddleware](api/faststream/opentelemetry/middleware/TelemetryMiddleware.md) + - provider + - [TelemetrySettingsProvider](api/faststream/opentelemetry/provider/TelemetrySettingsProvider.md) - rabbit - [ExchangeType](api/faststream/rabbit/ExchangeType.md) - [RabbitBroker](api/faststream/rabbit/RabbitBroker.md) @@ -662,6 +702,12 @@ search: - [RabbitRouter](api/faststream/rabbit/fastapi/router/RabbitRouter.md) - message - [RabbitMessage](api/faststream/rabbit/message/RabbitMessage.md) + - opentelemetry + - [RabbitTelemetryMiddleware](api/faststream/rabbit/opentelemetry/RabbitTelemetryMiddleware.md) + - middleware + - [RabbitTelemetryMiddleware](api/faststream/rabbit/opentelemetry/middleware/RabbitTelemetryMiddleware.md) + - provider + - [RabbitTelemetrySettingsProvider](api/faststream/rabbit/opentelemetry/provider/RabbitTelemetrySettingsProvider.md) - parser - [AioPikaParser](api/faststream/rabbit/parser/AioPikaParser.md) - publisher @@ -746,6 +792,12 @@ search: - [StreamMessage](api/faststream/redis/message/StreamMessage.md) - [UnifyRedisDict](api/faststream/redis/message/UnifyRedisDict.md) - [UnifyRedisMessage](api/faststream/redis/message/UnifyRedisMessage.md) + - opentelemetry + - [RedisTelemetryMiddleware](api/faststream/redis/opentelemetry/RedisTelemetryMiddleware.md) + - middleware + - [RedisTelemetryMiddleware](api/faststream/redis/opentelemetry/middleware/RedisTelemetryMiddleware.md) + - provider + - [RedisTelemetrySettingsProvider](api/faststream/redis/opentelemetry/provider/RedisTelemetrySettingsProvider.md) - parser - [RawMessage](api/faststream/redis/parser/RawMessage.md) - [RedisBatchListParser](api/faststream/redis/parser/RedisBatchListParser.md) diff --git a/docs/docs/assets/img/distributed-trace.png b/docs/docs/assets/img/distributed-trace.png new file mode 100644 index 0000000000..2b9e89de5f Binary files /dev/null and b/docs/docs/assets/img/distributed-trace.png differ diff --git a/docs/docs/assets/img/simple-trace.png b/docs/docs/assets/img/simple-trace.png new file mode 100644 index 0000000000..4aeb1868d6 Binary files /dev/null and b/docs/docs/assets/img/simple-trace.png differ diff --git a/docs/docs/en/api/faststream/confluent/opentelemetry/KafkaTelemetryMiddleware.md b/docs/docs/en/api/faststream/confluent/opentelemetry/KafkaTelemetryMiddleware.md new file mode 100644 index 0000000000..743c494591 --- /dev/null +++ b/docs/docs/en/api/faststream/confluent/opentelemetry/KafkaTelemetryMiddleware.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.confluent.opentelemetry.KafkaTelemetryMiddleware diff --git a/docs/docs/en/api/faststream/confluent/opentelemetry/middleware/KafkaTelemetryMiddleware.md b/docs/docs/en/api/faststream/confluent/opentelemetry/middleware/KafkaTelemetryMiddleware.md new file mode 100644 index 0000000000..b34265dfbb --- /dev/null +++ b/docs/docs/en/api/faststream/confluent/opentelemetry/middleware/KafkaTelemetryMiddleware.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.confluent.opentelemetry.middleware.KafkaTelemetryMiddleware diff --git a/docs/docs/en/api/faststream/confluent/opentelemetry/provider/BaseConfluentTelemetrySettingsProvider.md b/docs/docs/en/api/faststream/confluent/opentelemetry/provider/BaseConfluentTelemetrySettingsProvider.md new file mode 100644 index 0000000000..730662fae5 --- /dev/null +++ b/docs/docs/en/api/faststream/confluent/opentelemetry/provider/BaseConfluentTelemetrySettingsProvider.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.confluent.opentelemetry.provider.BaseConfluentTelemetrySettingsProvider diff --git a/docs/docs/en/api/faststream/confluent/opentelemetry/provider/BatchConfluentTelemetrySettingsProvider.md b/docs/docs/en/api/faststream/confluent/opentelemetry/provider/BatchConfluentTelemetrySettingsProvider.md new file mode 100644 index 0000000000..a6db133484 --- /dev/null +++ b/docs/docs/en/api/faststream/confluent/opentelemetry/provider/BatchConfluentTelemetrySettingsProvider.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.confluent.opentelemetry.provider.BatchConfluentTelemetrySettingsProvider diff --git a/docs/docs/en/api/faststream/confluent/opentelemetry/provider/ConfluentTelemetrySettingsProvider.md b/docs/docs/en/api/faststream/confluent/opentelemetry/provider/ConfluentTelemetrySettingsProvider.md new file mode 100644 index 0000000000..2c5242e6e5 --- /dev/null +++ b/docs/docs/en/api/faststream/confluent/opentelemetry/provider/ConfluentTelemetrySettingsProvider.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.confluent.opentelemetry.provider.ConfluentTelemetrySettingsProvider diff --git a/docs/docs/en/api/faststream/confluent/opentelemetry/provider/telemetry_attributes_provider_factory.md b/docs/docs/en/api/faststream/confluent/opentelemetry/provider/telemetry_attributes_provider_factory.md new file mode 100644 index 0000000000..7dd0e1d0fd --- /dev/null +++ b/docs/docs/en/api/faststream/confluent/opentelemetry/provider/telemetry_attributes_provider_factory.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.confluent.opentelemetry.provider.telemetry_attributes_provider_factory diff --git a/docs/docs/en/api/faststream/kafka/TopicPartition.md b/docs/docs/en/api/faststream/kafka/TopicPartition.md new file mode 100644 index 0000000000..41fbd7f624 --- /dev/null +++ b/docs/docs/en/api/faststream/kafka/TopicPartition.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: aiokafka.structs.TopicPartition diff --git a/docs/docs/en/api/faststream/kafka/opentelemetry/KafkaTelemetryMiddleware.md b/docs/docs/en/api/faststream/kafka/opentelemetry/KafkaTelemetryMiddleware.md new file mode 100644 index 0000000000..02fb4805ac --- /dev/null +++ b/docs/docs/en/api/faststream/kafka/opentelemetry/KafkaTelemetryMiddleware.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.kafka.opentelemetry.KafkaTelemetryMiddleware diff --git a/docs/docs/en/api/faststream/kafka/opentelemetry/middleware/KafkaTelemetryMiddleware.md b/docs/docs/en/api/faststream/kafka/opentelemetry/middleware/KafkaTelemetryMiddleware.md new file mode 100644 index 0000000000..aba78378f2 --- /dev/null +++ b/docs/docs/en/api/faststream/kafka/opentelemetry/middleware/KafkaTelemetryMiddleware.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.kafka.opentelemetry.middleware.KafkaTelemetryMiddleware diff --git a/docs/docs/en/api/faststream/kafka/opentelemetry/provider/BaseKafkaTelemetrySettingsProvider.md b/docs/docs/en/api/faststream/kafka/opentelemetry/provider/BaseKafkaTelemetrySettingsProvider.md new file mode 100644 index 0000000000..5cb13be947 --- /dev/null +++ b/docs/docs/en/api/faststream/kafka/opentelemetry/provider/BaseKafkaTelemetrySettingsProvider.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.kafka.opentelemetry.provider.BaseKafkaTelemetrySettingsProvider diff --git a/docs/docs/en/api/faststream/kafka/opentelemetry/provider/BatchKafkaTelemetrySettingsProvider.md b/docs/docs/en/api/faststream/kafka/opentelemetry/provider/BatchKafkaTelemetrySettingsProvider.md new file mode 100644 index 0000000000..d3d7080509 --- /dev/null +++ b/docs/docs/en/api/faststream/kafka/opentelemetry/provider/BatchKafkaTelemetrySettingsProvider.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.kafka.opentelemetry.provider.BatchKafkaTelemetrySettingsProvider diff --git a/docs/docs/en/api/faststream/kafka/opentelemetry/provider/KafkaTelemetrySettingsProvider.md b/docs/docs/en/api/faststream/kafka/opentelemetry/provider/KafkaTelemetrySettingsProvider.md new file mode 100644 index 0000000000..0859c0df3d --- /dev/null +++ b/docs/docs/en/api/faststream/kafka/opentelemetry/provider/KafkaTelemetrySettingsProvider.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.kafka.opentelemetry.provider.KafkaTelemetrySettingsProvider diff --git a/docs/docs/en/api/faststream/kafka/opentelemetry/provider/telemetry_attributes_provider_factory.md b/docs/docs/en/api/faststream/kafka/opentelemetry/provider/telemetry_attributes_provider_factory.md new file mode 100644 index 0000000000..3b2a1ad394 --- /dev/null +++ b/docs/docs/en/api/faststream/kafka/opentelemetry/provider/telemetry_attributes_provider_factory.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.kafka.opentelemetry.provider.telemetry_attributes_provider_factory diff --git a/docs/docs/en/api/faststream/nats/opentelemetry/NatsTelemetryMiddleware.md b/docs/docs/en/api/faststream/nats/opentelemetry/NatsTelemetryMiddleware.md new file mode 100644 index 0000000000..e72f2de8ab --- /dev/null +++ b/docs/docs/en/api/faststream/nats/opentelemetry/NatsTelemetryMiddleware.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.nats.opentelemetry.NatsTelemetryMiddleware diff --git a/docs/docs/en/api/faststream/nats/opentelemetry/middleware/NatsTelemetryMiddleware.md b/docs/docs/en/api/faststream/nats/opentelemetry/middleware/NatsTelemetryMiddleware.md new file mode 100644 index 0000000000..b2bb226585 --- /dev/null +++ b/docs/docs/en/api/faststream/nats/opentelemetry/middleware/NatsTelemetryMiddleware.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.nats.opentelemetry.middleware.NatsTelemetryMiddleware diff --git a/docs/docs/en/api/faststream/nats/opentelemetry/provider/BaseNatsTelemetrySettingsProvider.md b/docs/docs/en/api/faststream/nats/opentelemetry/provider/BaseNatsTelemetrySettingsProvider.md new file mode 100644 index 0000000000..d6626c537d --- /dev/null +++ b/docs/docs/en/api/faststream/nats/opentelemetry/provider/BaseNatsTelemetrySettingsProvider.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.nats.opentelemetry.provider.BaseNatsTelemetrySettingsProvider diff --git a/docs/docs/en/api/faststream/nats/opentelemetry/provider/NatsBatchTelemetrySettingsProvider.md b/docs/docs/en/api/faststream/nats/opentelemetry/provider/NatsBatchTelemetrySettingsProvider.md new file mode 100644 index 0000000000..045996125a --- /dev/null +++ b/docs/docs/en/api/faststream/nats/opentelemetry/provider/NatsBatchTelemetrySettingsProvider.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.nats.opentelemetry.provider.NatsBatchTelemetrySettingsProvider diff --git a/docs/docs/en/api/faststream/nats/opentelemetry/provider/NatsTelemetrySettingsProvider.md b/docs/docs/en/api/faststream/nats/opentelemetry/provider/NatsTelemetrySettingsProvider.md new file mode 100644 index 0000000000..b58590c4fa --- /dev/null +++ b/docs/docs/en/api/faststream/nats/opentelemetry/provider/NatsTelemetrySettingsProvider.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.nats.opentelemetry.provider.NatsTelemetrySettingsProvider diff --git a/docs/docs/en/api/faststream/nats/opentelemetry/provider/telemetry_attributes_provider_factory.md b/docs/docs/en/api/faststream/nats/opentelemetry/provider/telemetry_attributes_provider_factory.md new file mode 100644 index 0000000000..200d333e0b --- /dev/null +++ b/docs/docs/en/api/faststream/nats/opentelemetry/provider/telemetry_attributes_provider_factory.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.nats.opentelemetry.provider.telemetry_attributes_provider_factory diff --git a/docs/docs/en/api/faststream/opentelemetry/TelemetryMiddleware.md b/docs/docs/en/api/faststream/opentelemetry/TelemetryMiddleware.md new file mode 100644 index 0000000000..914f134e60 --- /dev/null +++ b/docs/docs/en/api/faststream/opentelemetry/TelemetryMiddleware.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.opentelemetry.TelemetryMiddleware diff --git a/docs/docs/en/api/faststream/opentelemetry/TelemetrySettingsProvider.md b/docs/docs/en/api/faststream/opentelemetry/TelemetrySettingsProvider.md new file mode 100644 index 0000000000..7ca8b2cb6d --- /dev/null +++ b/docs/docs/en/api/faststream/opentelemetry/TelemetrySettingsProvider.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.opentelemetry.TelemetrySettingsProvider diff --git a/docs/docs/en/api/faststream/opentelemetry/consts/MessageAction.md b/docs/docs/en/api/faststream/opentelemetry/consts/MessageAction.md new file mode 100644 index 0000000000..cd58706774 --- /dev/null +++ b/docs/docs/en/api/faststream/opentelemetry/consts/MessageAction.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.opentelemetry.consts.MessageAction diff --git a/docs/docs/en/api/faststream/opentelemetry/middleware/BaseTelemetryMiddleware.md b/docs/docs/en/api/faststream/opentelemetry/middleware/BaseTelemetryMiddleware.md new file mode 100644 index 0000000000..64a7b4a501 --- /dev/null +++ b/docs/docs/en/api/faststream/opentelemetry/middleware/BaseTelemetryMiddleware.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.opentelemetry.middleware.BaseTelemetryMiddleware diff --git a/docs/docs/en/api/faststream/opentelemetry/middleware/TelemetryMiddleware.md b/docs/docs/en/api/faststream/opentelemetry/middleware/TelemetryMiddleware.md new file mode 100644 index 0000000000..f019b3ad61 --- /dev/null +++ b/docs/docs/en/api/faststream/opentelemetry/middleware/TelemetryMiddleware.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.opentelemetry.middleware.TelemetryMiddleware diff --git a/docs/docs/en/api/faststream/opentelemetry/provider/TelemetrySettingsProvider.md b/docs/docs/en/api/faststream/opentelemetry/provider/TelemetrySettingsProvider.md new file mode 100644 index 0000000000..0fefe1c0ef --- /dev/null +++ b/docs/docs/en/api/faststream/opentelemetry/provider/TelemetrySettingsProvider.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.opentelemetry.provider.TelemetrySettingsProvider diff --git a/docs/docs/en/api/faststream/rabbit/opentelemetry/RabbitTelemetryMiddleware.md b/docs/docs/en/api/faststream/rabbit/opentelemetry/RabbitTelemetryMiddleware.md new file mode 100644 index 0000000000..7d5ef3de27 --- /dev/null +++ b/docs/docs/en/api/faststream/rabbit/opentelemetry/RabbitTelemetryMiddleware.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.rabbit.opentelemetry.RabbitTelemetryMiddleware diff --git a/docs/docs/en/api/faststream/rabbit/opentelemetry/middleware/RabbitTelemetryMiddleware.md b/docs/docs/en/api/faststream/rabbit/opentelemetry/middleware/RabbitTelemetryMiddleware.md new file mode 100644 index 0000000000..e86771a8ba --- /dev/null +++ b/docs/docs/en/api/faststream/rabbit/opentelemetry/middleware/RabbitTelemetryMiddleware.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.rabbit.opentelemetry.middleware.RabbitTelemetryMiddleware diff --git a/docs/docs/en/api/faststream/rabbit/opentelemetry/provider/RabbitTelemetrySettingsProvider.md b/docs/docs/en/api/faststream/rabbit/opentelemetry/provider/RabbitTelemetrySettingsProvider.md new file mode 100644 index 0000000000..ba6742ac90 --- /dev/null +++ b/docs/docs/en/api/faststream/rabbit/opentelemetry/provider/RabbitTelemetrySettingsProvider.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.rabbit.opentelemetry.provider.RabbitTelemetrySettingsProvider diff --git a/docs/docs/en/api/faststream/redis/opentelemetry/RedisTelemetryMiddleware.md b/docs/docs/en/api/faststream/redis/opentelemetry/RedisTelemetryMiddleware.md new file mode 100644 index 0000000000..537a2dc7b9 --- /dev/null +++ b/docs/docs/en/api/faststream/redis/opentelemetry/RedisTelemetryMiddleware.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.redis.opentelemetry.RedisTelemetryMiddleware diff --git a/docs/docs/en/api/faststream/redis/opentelemetry/middleware/RedisTelemetryMiddleware.md b/docs/docs/en/api/faststream/redis/opentelemetry/middleware/RedisTelemetryMiddleware.md new file mode 100644 index 0000000000..4c0febf261 --- /dev/null +++ b/docs/docs/en/api/faststream/redis/opentelemetry/middleware/RedisTelemetryMiddleware.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.redis.opentelemetry.middleware.RedisTelemetryMiddleware diff --git a/docs/docs/en/api/faststream/redis/opentelemetry/provider/RedisTelemetrySettingsProvider.md b/docs/docs/en/api/faststream/redis/opentelemetry/provider/RedisTelemetrySettingsProvider.md new file mode 100644 index 0000000000..26e7859c34 --- /dev/null +++ b/docs/docs/en/api/faststream/redis/opentelemetry/provider/RedisTelemetrySettingsProvider.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.redis.opentelemetry.provider.RedisTelemetrySettingsProvider diff --git a/docs/docs/en/getting-started/opentelemetry/index.md b/docs/docs/en/getting-started/opentelemetry/index.md new file mode 100644 index 0000000000..44e7fe9013 --- /dev/null +++ b/docs/docs/en/getting-started/opentelemetry/index.md @@ -0,0 +1,114 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 10 +--- + +# OpenTelemetry + +**OpenTelemetry** is an open-source observability framework designed to provide a unified standard for collecting and exporting telemetry data such as traces, metrics, and logs. It aims to make observability a built-in feature of software development, simplifying the integration and standardization of telemetry data across various services. For more details, you can read the official [OpenTelemetry documentation](https://opentelemetry.io/){.external-link target="_blank"}. + +## Tracing + +Tracing is a form of observability that tracks the flow of requests as they move through various services in a distributed system. It provides insights into the interactions between services, highlighting performance bottlenecks and errors. The result of implementing tracing is a detailed map of the service interactions, often visualized as a trace diagram. This helps developers understand the behavior and performance of their applications. For an in-depth explanation, refer to the [OpenTelemetry tracing specification](https://opentelemetry.io/docs/concepts/signals/traces/){.external-link target="_blank"}. + +![HTML-page](../../../assets/img/simple-trace.png){ loading=lazy } +`Visualized via Grafana and Tempo` + +This trace is derived from this relationship between handlers: + +```python linenums="1" +@broker.subscriber("first") +@broker.publisher("second") +async def first_handler(msg: str): + await asyncio.sleep(0.1) + return msg + + +@broker.subscriber("second") +@broker.publisher("third") +async def second_handler(msg: str): + await asyncio.sleep(0.05) + return msg + + +@broker.subscriber("third") +async def third_handler(msg: str): + await asyncio.sleep(0.075) +``` + +## FastStream Tracing + +**OpenTelemetry** tracing support in **FastStream** adheres to the [semantic conventions for messaging systems](https://opentelemetry.io/docs/specs/semconv/messaging/){.external-link target="_blank"}. + +To add a trace to your broker, you need to: + +1. Install `FastStream` with `opentelemetry-sdk` + + ```shell + pip install faststream[otel] + ``` + +2. Configure `TracerProvider` + + ```python linenums="1" hl_lines="5-7" + from opentelemetry import trace + from opentelemetry.sdk.resources import Resource + from opentelemetry.sdk.trace import TracerProvider + + resource = Resource.create(attributes={"service.name": "faststream"}) + tracer_provider = TracerProvider(resource=resource) + trace.set_tracer_provider(tracer_provider) + ``` + +3. Add `TelemetryMiddleware` to your broker + + {!> includes/getting_started/opentelemetry/1.md !} + +### Exporting + +To export traces, you must select and configure an exporter yourself: + +* [opentelemetry-exporter-jaeger](https://pypi.org/project/opentelemetry-exporter-jaeger/){.external-link target="_blank"} to export to **Jaeger** +* [opentelemetry-exporter-otlp](https://pypi.org/project/opentelemetry-exporter-otlp/){.external-link target="_blank"} for export via **gRPC** or **HTTP** +* ``InMemorySpanExporter`` from ``opentelemetry.sdk.trace.export.in_memory_span_exporter`` for local tests + +There are other exporters. + +Configuring the export of traces via `opentelemetry-exporter-otlp`: + +```python linenums="1" hl_lines="4-6" +from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter +from opentelemetry.sdk.trace.export import BatchSpanProcessor + +exporter = OTLPSpanExporter(endpoint="http://127.0.0.1:4317") +processor = BatchSpanProcessor(exporter) +tracer_provider.add_span_processor(processor) +``` + +### Visualization + +To visualize traces, you can send them to a backend system that supports distributed tracing, such as **Jaeger**, **Zipkin**, or **Grafana Tempo**. These systems provide a user interface to visualize and analyze traces. + +* **Jaeger**: You can run **Jaeger** using Docker and configure your **OpenTelemetry** middleware to send traces to **Jaeger**. For more details, see the [Jaeger documentation](https://www.jaegertracing.io/){.external-link target="_blank"}. +* **Zipkin**: Similar to **Jaeger**, you can run **Zipkin** using **Docker** and configure the **OpenTelemetry** middleware accordingly. For more details, see the [Zipkin documentation](https://zipkin.io/){.external-link target="_blank"}. +* **Grafana Tempo**: **Grafana Tempo** is a high-scale distributed tracing backend. You can configure **OpenTelemetry** to export traces to **Tempo**, which can then be visualized using **Grafana**. For more details, see the [Grafana Tempo documentation](https://grafana.com/docs/tempo/latest/){.external-link target="_blank"}. + +## Example + +To see how to set up, visualize, and configure tracing for **FastStream** services, go to [example](https://github.com/draincoder/faststream-monitoring){.external-link target="_blank"}. + +An example includes: + +* Three `FastStream` services +* Exporting traces to `Grafana Tempo` via `gRPC` +* Visualization of traces via `Grafana` +* Examples with custom spans +* Configured `docker-compose` with the entire infrastructure + +![HTML-page](../../../assets/img/distributed-trace.png){ loading=lazy } +`Visualized via Grafana and Tempo` diff --git a/docs/docs/en/release.md b/docs/docs/en/release.md index c56b3eea34..ce6828b2fa 100644 --- a/docs/docs/en/release.md +++ b/docs/docs/en/release.md @@ -12,6 +12,131 @@ hide: --- # Release Notes +## 0.5.7 + +### What's Changed + +Finally, FastStream supports [OpenTelemetry](https://opentelemetry.io/) in a native way to collect the full trace of your services! Big thanks for @draincoder for that! + +First of all you need to install required dependencies to support OpenTelemetry: + +```bash +pip install faststream[otel] +``` + +Then you can just add a middleware for your broker and that's it! + +```python +from faststream import FastStream +from faststream.nats import NatsBroker +from faststream.nats.opentelemetry import NatsTelemetryMiddleware + +broker = NatsBroker( + middlewares=( + NatsTelemetryMiddleware(), + ) +) +app = FastStream(broker) +``` + +To find detailt information just visit our documentation aboout [telemetry](https://faststream.airt.ai/latest/getting-started/opentelemetry/) + +P.S. The release includes basic OpenTelemetry support - messages tracing & basic metrics. Baggage support and correct spans linking in batch processing case will be added soon. + +* fix: serialize TestClient rpc output to mock the real message by @Lancetnik in https://github.com/airtai/faststream/pull/1452 +* feature (#916): Observability by @draincoder in https://github.com/airtai/faststream/pull/1398 + +### New Contributors +* @draincoder made their first contribution in https://github.com/airtai/faststream/pull/1398 + +**Full Changelog**: https://github.com/airtai/faststream/compare/0.5.6...0.5.7 + +## 0.5.6 + +### What's Changed + +* feature: add --factory param by [@Sehat1137](https://github.com/Sehat1137){.external-link target="_blank"} in [#1440](https://github.com/airtai/faststream/pull/1440){.external-link target="_blank"} +* feat: add RMQ channels options, support for prefix for routing_key, a… by [@Lancetnik](https://github.com/Lancetnik){.external-link target="_blank"} in [#1448](https://github.com/airtai/faststream/pull/1448){.external-link target="_blank"} +* feature: Add `from faststream.rabbit.annotations import Connection, Channel` shortcuts +* Bugfix: RabbitMQ RabbitRouter prefix now affects to queue routing key as well +* Feature (close #1402): add `broker.add_middleware` public API to append a middleware to already created broker +* Feature: add `RabbitBroker(channel_number: int, publisher_confirms: bool, on_return_raises: bool)` options to setup channel settings +* Feature (close #1447): add `StreamMessage.batch_headers` attribute to provide with access to whole batch messages headers + +### New Contributors + +* [@Sehat1137](https://github.com/Sehat1137){.external-link target="_blank"} made their first contribution in [#1440](https://github.com/airtai/faststream/pull/1440){.external-link target="_blank"} + +**Full Changelog**: [#0.5.5...0.5.6](https://github.com/airtai/faststream/compare/0.5.5...0.5.6){.external-link target="_blank"} + +## 0.5.5 + +### What's Changed + +Add support for explicit partition assignment in aiokafka `KafkaBroker` (special thanks to @spataphore1337): + +```python +from faststream import FastStream +from faststream.kafka import KafkaBroker, TopicPartition + +broker = KafkaBroker() + +topic_partition_fisrt = TopicPartition("my_topic", 1) +topic_partition_second = TopicPartition("my_topic", 2) + +@broker.subscribe(partitions=[topic_partition_fisrt, topic_partition_second]) +async def some_consumer(msg): + ... +``` + +* Update Release Notes for 0.5.4 by @faststream-release-notes-updater in [#1421](https://github.com/airtai/faststream/pull/1421){.external-link target="_blank"} +* feature: manual partition assignment to Kafka by [@spataphore1337](https://github.com/spataphore1337){.external-link target="_blank"} in [#1422](https://github.com/airtai/faststream/pull/1422){.external-link target="_blank"} +* Chore/update deps by [@Lancetnik](https://github.com/Lancetnik){.external-link target="_blank"} in [#1429](https://github.com/airtai/faststream/pull/1429){.external-link target="_blank"} +* Fix/correct dynamic subscriber registration by [@Lancetnik](https://github.com/Lancetnik){.external-link target="_blank"} in [#1433](https://github.com/airtai/faststream/pull/1433){.external-link target="_blank"} +* chore: bump version by [@Lancetnik](https://github.com/Lancetnik){.external-link target="_blank"} in [#1435](https://github.com/airtai/faststream/pull/1435){.external-link target="_blank"} + + +**Full Changelog**: [#0.5.4...0.5.5](https://github.com/airtai/faststream/compare/0.5.4...0.5.5){.external-link target="_blank"} + +## 0.5.4 + +### What's Changed + +* Update Release Notes for 0.5.3 by @faststream-release-notes-updater in [#1400](https://github.com/airtai/faststream/pull/1400){.external-link target="_blank"} +* fix (#1415): raise SetupError if rpc and reply_to are using in TestCL… by [@Lancetnik](https://github.com/Lancetnik){.external-link target="_blank"} in [#1419](https://github.com/airtai/faststream/pull/1419){.external-link target="_blank"} +* Chore/update deps2 by [@Lancetnik](https://github.com/Lancetnik){.external-link target="_blank"} in [#1418](https://github.com/airtai/faststream/pull/1418){.external-link target="_blank"} +* refactor: correct security with kwarg params merging by [@Lancetnik](https://github.com/Lancetnik){.external-link target="_blank"} in [#1417](https://github.com/airtai/faststream/pull/1417){.external-link target="_blank"} +* fix (#1414): correct Message.ack error processing by [@Lancetnik](https://github.com/Lancetnik){.external-link target="_blank"} in [#1420](https://github.com/airtai/faststream/pull/1420){.external-link target="_blank"} + +**Full Changelog**: [#0.5.3...0.5.4](https://github.com/airtai/faststream/compare/0.5.3...0.5.4){.external-link target="_blank"} + +## 0.5.3 + +### What's Changed +* Update Release Notes for 0.5.2 by @faststream-release-notes-updater in [#1382](https://github.com/airtai/faststream/pull/1382){.external-link target="_blank"} +* Fix/setup at broker connection instead of starting by [@Lancetnik](https://github.com/Lancetnik){.external-link target="_blank"} in [#1385](https://github.com/airtai/faststream/pull/1385){.external-link target="_blank"} +* Tests/add path tests by [@Lancetnik](https://github.com/Lancetnik){.external-link target="_blank"} in [#1388](https://github.com/airtai/faststream/pull/1388){.external-link target="_blank"} +* Fix/path with router prefix by [@Lancetnik](https://github.com/Lancetnik){.external-link target="_blank"} in [#1395](https://github.com/airtai/faststream/pull/1395){.external-link target="_blank"} +* chore: update dependencies by [@Lancetnik](https://github.com/Lancetnik){.external-link target="_blank"} in [#1396](https://github.com/airtai/faststream/pull/1396){.external-link target="_blank"} +* chore: bump version by [@Lancetnik](https://github.com/Lancetnik){.external-link target="_blank"} in [#1397](https://github.com/airtai/faststream/pull/1397){.external-link target="_blank"} +* chore: polishing by [@davorrunje](https://github.com/davorrunje){.external-link target="_blank"} in [#1399](https://github.com/airtai/faststream/pull/1399){.external-link target="_blank"} + + +**Full Changelog**: [#0.5.2...0.5.3](https://github.com/airtai/faststream/compare/0.5.2...0.5.3){.external-link target="_blank"} + +## 0.5.2 + +### What's Changed + +Just a little bugfix patch. Fixes #1379 and #1376. + +* Update Release Notes for 0.5.1 by @faststream-release-notes-updater in [#1378](https://github.com/airtai/faststream/pull/1378){.external-link target="_blank"} +* Tests/fastapi background by [@Lancetnik](https://github.com/Lancetnik){.external-link target="_blank"} in [#1380](https://github.com/airtai/faststream/pull/1380){.external-link target="_blank"} +* Fix/0.5.2 by [@Lancetnik](https://github.com/Lancetnik){.external-link target="_blank"} in [#1381](https://github.com/airtai/faststream/pull/1381){.external-link target="_blank"} + + +**Full Changelog**: [#0.5.1...0.5.2](https://github.com/airtai/faststream/compare/0.5.1...0.5.2){.external-link target="_blank"} + ## 0.5.1 ### What's Changed @@ -29,7 +154,7 @@ We already have some fixes related to `RedisBroker` (#1375, #1376) and some new include_in_schema=False, ) ``` - + 2. `KafkaBroker().subscriber(...)` now consumes `aiokafka.ConsumerRebalanceListener` object. You can find more information about it in the official [**aiokafka** doc](https://aiokafka.readthedocs.io/en/stable/consumer.html?highlight=subscribe#topic-subscription-by-pattern) @@ -37,7 +162,7 @@ You can find more information about it in the official [**aiokafka** doc](https: ```python broker = KafkaBroker() - + broker.subscriber(..., listener=MyRebalancer()) ``` @@ -88,37 +213,37 @@ subscriber = broker.subscriber("test") @subscriber(filter = lambda msg: msg.content_type == "application/json") async def handler(msg: dict[str, Any]): ... - + @subscriber() async def handler(msg: dict[str, Any]): ... ``` - + This is the preferred syntax for [filtering](https://faststream.airt.ai/latest/getting-started/subscription/filtering/) now (the old one will be removed in `0.6.0`) - + 3. The `router.publisher()` function now returns the correct `Publisher` object you can use later (after broker startup). - + ```python publisher = router.publisher("test") - + @router.subscriber("in") async def handler(): await publisher.publish("msg") ``` - + (Until `0.5.0` you could use it in this way with `broker.publisher` only) - + 4. A list of `middlewares` can be passed to a `broker.publisher` as well: - + ```python broker = Broker(..., middlewares=()) - + @broker.subscriber(..., middlewares=()) @broker.publisher(..., middlewares=()) # new feature async def handler(): ... ``` - + 5. Broker-level middlewares now affect all ways to publish a message, so you can encode application outgoing messages here. 6. ⚠️ BREAKING CHANGE ⚠️ : both `subscriber` and `publisher` middlewares should be async context manager type @@ -182,7 +307,7 @@ await subscriber.close() * close #568 * close #1303 * close #1287 - * feat #607 + * feat #607 * Generate docs and linter fixes by @davorrunje in https://github.com/airtai/faststream/pull/1348 * Fix types by @davorrunje in https://github.com/airtai/faststream/pull/1349 * chore: update dependencies by @Lancetnik in https://github.com/airtai/faststream/pull/1358 diff --git a/docs/docs/navigation_template.txt b/docs/docs/navigation_template.txt index 87df76aa6c..fa23f9c3c5 100644 --- a/docs/docs/navigation_template.txt +++ b/docs/docs/navigation_template.txt @@ -41,6 +41,7 @@ search: - [FastAPI Plugin](getting-started/integrations/fastapi/index.md) - [Django](getting-started/integrations/django/index.md) - [CLI commands](getting-started/cli/index.md) + - [OpenTelemetry](getting-started/opentelemetry/index.md) - [Logging](getting-started/logging.md) - [Config Management](getting-started/config/index.md) - [Task Scheduling](scheduling.md) diff --git a/docs/docs_src/getting_started/opentelemetry/__init__.py b/docs/docs_src/getting_started/opentelemetry/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/docs/docs_src/getting_started/opentelemetry/confluent_telemetry.py b/docs/docs_src/getting_started/opentelemetry/confluent_telemetry.py new file mode 100644 index 0000000000..e9e3175d6d --- /dev/null +++ b/docs/docs_src/getting_started/opentelemetry/confluent_telemetry.py @@ -0,0 +1,10 @@ +from faststream import FastStream +from faststream.confluent import KafkaBroker +from faststream.confluent.opentelemetry import KafkaTelemetryMiddleware + +broker = KafkaBroker( + middlewares=( + KafkaTelemetryMiddleware(tracer_provider=tracer_provider) + ) +) +app = FastStream(broker) diff --git a/docs/docs_src/getting_started/opentelemetry/kafka_telemetry.py b/docs/docs_src/getting_started/opentelemetry/kafka_telemetry.py new file mode 100644 index 0000000000..4bbfd9d9d8 --- /dev/null +++ b/docs/docs_src/getting_started/opentelemetry/kafka_telemetry.py @@ -0,0 +1,10 @@ +from faststream import FastStream +from faststream.kafka import KafkaBroker +from faststream.kafka.opentelemetry import KafkaTelemetryMiddleware + +broker = KafkaBroker( + middlewares=( + KafkaTelemetryMiddleware(tracer_provider=tracer_provider), + ) +) +app = FastStream(broker) diff --git a/docs/docs_src/getting_started/opentelemetry/nats_telemetry.py b/docs/docs_src/getting_started/opentelemetry/nats_telemetry.py new file mode 100644 index 0000000000..f503e22050 --- /dev/null +++ b/docs/docs_src/getting_started/opentelemetry/nats_telemetry.py @@ -0,0 +1,10 @@ +from faststream import FastStream +from faststream.nats import NatsBroker +from faststream.nats.opentelemetry import NatsTelemetryMiddleware + +broker = NatsBroker( + middlewares=( + NatsTelemetryMiddleware(tracer_provider=tracer_provider), + ) +) +app = FastStream(broker) diff --git a/docs/docs_src/getting_started/opentelemetry/rabbit_telemetry.py b/docs/docs_src/getting_started/opentelemetry/rabbit_telemetry.py new file mode 100644 index 0000000000..4dea2f919f --- /dev/null +++ b/docs/docs_src/getting_started/opentelemetry/rabbit_telemetry.py @@ -0,0 +1,10 @@ +from faststream import FastStream +from faststream.rabbit import RabbitBroker +from faststream.rabbit.opentelemetry import RabbitTelemetryMiddleware + +broker = RabbitBroker( + middlewares=( + RabbitTelemetryMiddleware(tracer_provider=tracer_provider), + ) +) +app = FastStream(broker) diff --git a/docs/docs_src/getting_started/opentelemetry/redis_telemetry.py b/docs/docs_src/getting_started/opentelemetry/redis_telemetry.py new file mode 100644 index 0000000000..2de8174264 --- /dev/null +++ b/docs/docs_src/getting_started/opentelemetry/redis_telemetry.py @@ -0,0 +1,10 @@ +from faststream import FastStream +from faststream.redis import RedisBroker +from faststream.redis.opentelemetry import RedisTelemetryMiddleware + +broker = RedisBroker( + middlewares=( + RedisTelemetryMiddleware(tracer_provider=tracer_provider), + ) +) +app = FastStream(broker) diff --git a/docs/includes/getting_started/opentelemetry/1.md b/docs/includes/getting_started/opentelemetry/1.md new file mode 100644 index 0000000000..5ddf58d192 --- /dev/null +++ b/docs/includes/getting_started/opentelemetry/1.md @@ -0,0 +1,24 @@ +=== "AIOKafka" + ```python linenums="1" hl_lines="7" + {!> docs_src/getting_started/opentelemetry/kafka_telemetry.py!} + ``` + +=== "Confluent" + ```python linenums="1" hl_lines="7" + {!> docs_src/getting_started/opentelemetry/confluent_telemetry.py!} + ``` + +=== "RabbitMQ" + ```python linenums="1" hl_lines="7" + {!> docs_src/getting_started/opentelemetry/rabbit_telemetry.py!} + ``` + +=== "NATS" + ```python linenums="1" hl_lines="7" + {!> docs_src/getting_started/opentelemetry/nats_telemetry.py!} + ``` + +=== "Redis" + ```python linenums="1" hl_lines="7" + {!> docs_src/getting_started/opentelemetry/redis_telemetry.py!} + ``` diff --git a/faststream/__about__.py b/faststream/__about__.py index 1a82140e1e..6a9efa082f 100644 --- a/faststream/__about__.py +++ b/faststream/__about__.py @@ -1,6 +1,6 @@ """Simple and fast framework to create message brokers based microservices.""" -__version__ = "0.5.2" +__version__ = "0.5.7" SERVICE_NAME = f"faststream-{__version__}" diff --git a/faststream/asyncapi/site.py b/faststream/asyncapi/site.py index 73184f9bb4..fcc0aefea6 100644 --- a/faststream/asyncapi/site.py +++ b/faststream/asyncapi/site.py @@ -102,7 +102,7 @@ def serve_app( ) -> None: """Serve the HTTPServer with AsyncAPI schema.""" logger.info(f"HTTPServer running on http://{host}:{port} (Press CTRL+C to quit)") - logger.warn("Please, do not use it in production.") + logger.warning("Please, do not use it in production.") server.HTTPServer( (host, port), diff --git a/faststream/broker/acknowledgement_watcher.py b/faststream/broker/acknowledgement_watcher.py index 4ecb6ad4b3..dabc6eb87f 100644 --- a/faststream/broker/acknowledgement_watcher.py +++ b/faststream/broker/acknowledgement_watcher.py @@ -126,11 +126,13 @@ def __init__( self, message: "StreamMessage[MsgType]", watcher: BaseWatcher, + logger: Optional["LoggerProto"] = None, **extra_options: Any, ) -> None: self.watcher = watcher self.message = message self.extra_options = extra_options + self.logger = logger async def __aenter__(self) -> None: self.watcher.add(self.message.message_id) @@ -172,15 +174,29 @@ async def __aexit__( return not is_test_env() async def __ack(self) -> None: - await self.message.ack(**self.extra_options) - self.watcher.remove(self.message.message_id) + try: + await self.message.ack(**self.extra_options) + except Exception as er: + if self.logger is not None: + self.logger.log(logging.ERROR, er, exc_info=er) + else: + self.watcher.remove(self.message.message_id) async def __nack(self) -> None: - await self.message.nack(**self.extra_options) + try: + await self.message.nack(**self.extra_options) + except Exception as er: + if self.logger is not None: + self.logger.log(logging.ERROR, er, exc_info=er) async def __reject(self) -> None: - await self.message.reject(**self.extra_options) - self.watcher.remove(self.message.message_id) + try: + await self.message.reject(**self.extra_options) + except Exception as er: + if self.logger is not None: + self.logger.log(logging.ERROR, er, exc_info=er) + else: + self.watcher.remove(self.message.message_id) def get_watcher( diff --git a/faststream/broker/core/abc.py b/faststream/broker/core/abc.py index 1a49e26843..eb1a49bb7b 100644 --- a/faststream/broker/core/abc.py +++ b/faststream/broker/core/abc.py @@ -46,6 +46,19 @@ def __init__( self._parser = parser self._decoder = decoder + def add_middleware(self, middleware: "BrokerMiddleware[MsgType]") -> None: + """Append BrokerMiddleware to the end of middlewares list. + + Current middleware will be used as a most inner of already existed ones. + """ + self._middlewares = (*self._middlewares, middleware) + + for sub in self._subscribers.values(): + sub.add_middleware(middleware) + + for pub in self._publishers.values(): + pub.add_middleware(middleware) + @abstractmethod def subscriber( self, @@ -94,10 +107,10 @@ def include_router( *middlewares, *h._broker_middlewares, ) - h._broker_dependecies = ( + h._broker_dependencies = ( *self._dependencies, *dependencies, - *h._broker_dependecies, + *h._broker_dependencies, ) self._subscribers = {**self._subscribers, key: h} diff --git a/faststream/broker/core/usecase.py b/faststream/broker/core/usecase.py index 81bb903f43..c226850ace 100644 --- a/faststream/broker/core/usecase.py +++ b/faststream/broker/core/usecase.py @@ -42,7 +42,7 @@ from faststream.asyncapi.schema import Tag, TagDict from faststream.broker.publisher.proto import ProducerProto, PublisherProto from faststream.security import BaseSecurity - from faststream.types import AnyDict, AsyncFunc, Decorator, LoggerProto + from faststream.types import AnyDict, Decorator, LoggerProto class BrokerUsecase( @@ -214,7 +214,20 @@ async def start(self) -> None: """Start the broker async use case.""" self._abc_start() await self.connect() + + async def connect(self, **kwargs: Any) -> ConnectionType: + """Connect to a remote server.""" + if self._connection is None: + connection_kwargs = self._connection_kwargs.copy() + connection_kwargs.update(kwargs) + self._connection = await self._connect(**connection_kwargs) self.setup() + return self._connection + + @abstractmethod + async def _connect(self) -> ConnectionType: + """Connect to a resource.""" + raise NotImplementedError() def setup(self) -> None: """Prepare all Broker entities to startup.""" @@ -230,22 +243,9 @@ def setup_subscriber( **kwargs: Any, ) -> None: """Setup the Subscriber to prepare it to starting.""" - subscriber.setup( - logger=self.logger, - producer=self._producer, - graceful_timeout=self.graceful_timeout, - extra_context={}, - # broker options - broker_parser=self._parser, - broker_decoder=self._decoder, - # dependant args - apply_types=self._is_apply_types, - is_validate=self._is_validate, - _get_dependant=self._get_dependant, - _call_decorators=self._call_decorators, - **self._subscriber_setup_extra, - **kwargs, - ) + data = self._subscriber_setup_extra.copy() + data.update(kwargs) + subscriber.setup(**data) def setup_publisher( self, @@ -253,19 +253,32 @@ def setup_publisher( **kwargs: Any, ) -> None: """Setup the Publisher to prepare it to starting.""" - publisher.setup( - producer=self._producer, - **self._publisher_setup_extra, - **kwargs, - ) + data = self._publisher_setup_extra.copy() + data.update(kwargs) + publisher.setup(**data) @property def _subscriber_setup_extra(self) -> "AnyDict": - return {} + return { + "logger": self.logger, + "producer": self._producer, + "graceful_timeout": self.graceful_timeout, + "extra_context": {}, + # broker options + "broker_parser": self._parser, + "broker_decoder": self._decoder, + # dependant args + "apply_types": self._is_apply_types, + "is_validate": self._is_validate, + "_get_dependant": self._get_dependant, + "_call_decorators": self._call_decorators, + } @property def _publisher_setup_extra(self) -> "AnyDict": - return {} + return { + "producer": self._producer, + } def publisher(self, *args: Any, **kwargs: Any) -> "PublisherProto[MsgType]": pub = super().publisher(*args, **kwargs) @@ -288,19 +301,6 @@ def _abc_start(self) -> None: self._get_fmt(), ) - async def connect(self, **kwargs: Any) -> ConnectionType: - """Connect to a remote server.""" - if self._connection is None: - connection_kwargs = self._connection_kwargs.copy() - connection_kwargs.update(kwargs) - self._connection = await self._connect(**connection_kwargs) - return self._connection - - @abstractmethod - async def _connect(self) -> ConnectionType: - """Connect to a resource.""" - raise NotImplementedError() - async def close( self, exc_type: Optional[Type[BaseException]] = None, @@ -334,9 +334,10 @@ async def publish( **kwargs: Any, ) -> Optional[Any]: """Publish message directly.""" - assert producer, NOT_CONNECTED_YET # nosec B101) + assert producer, NOT_CONNECTED_YET # nosec B101 + + publish = producer.publish - publish: "AsyncFunc" = producer.publish for m in self._middlewares: publish = partial(m(None).publish_scope, publish) diff --git a/faststream/broker/message.py b/faststream/broker/message.py index 3f6cef306a..beec9fe555 100644 --- a/faststream/broker/message.py +++ b/faststream/broker/message.py @@ -1,15 +1,18 @@ import json from contextlib import suppress from dataclasses import dataclass, field +from inspect import Parameter from typing import ( TYPE_CHECKING, Any, Generic, + List, Optional, Sequence, Tuple, TypeVar, Union, + cast, ) from uuid import uuid4 @@ -36,6 +39,7 @@ class StreamMessage(Generic[MsgType]): body: Union[bytes, Any] headers: "AnyDict" = field(default_factory=dict) + batch_headers: List["AnyDict"] = field(default_factory=list) path: "AnyDict" = field(default_factory=dict) content_type: Optional[str] = None @@ -64,16 +68,23 @@ def decode_message(message: "StreamMessage[Any]") -> "DecodedMessage": body: Any = getattr(message, "body", message) m: "DecodedMessage" = body - if content_type := getattr(message, "content_type", None): - if ContentTypes.text.value in content_type: + if ( + content_type := getattr(message, "content_type", Parameter.empty) + ) is not Parameter.empty: + content_type = cast(Optional[str], content_type) + + if not content_type: + with suppress(json.JSONDecodeError, UnicodeDecodeError): + m = json_loads(body) + + elif ContentTypes.text.value in content_type: m = body.decode() - elif ContentTypes.json.value in content_type: # pragma: no branch + + elif ContentTypes.json.value in content_type: m = json_loads(body) - else: - with suppress(json.JSONDecodeError): - m = json_loads(body) + else: - with suppress(json.JSONDecodeError): + with suppress(json.JSONDecodeError, UnicodeDecodeError): m = json_loads(body) return m diff --git a/faststream/broker/publisher/proto.py b/faststream/broker/publisher/proto.py index 2233739252..747b29b048 100644 --- a/faststream/broker/publisher/proto.py +++ b/faststream/broker/publisher/proto.py @@ -56,6 +56,9 @@ class PublisherProto( _middlewares: Iterable["PublisherMiddleware"] _producer: Optional["ProducerProto"] + @abstractmethod + def add_middleware(self, middleware: "BrokerMiddleware[MsgType]") -> None: ... + @staticmethod @abstractmethod def create() -> "PublisherProto[MsgType]": diff --git a/faststream/broker/publisher/usecase.py b/faststream/broker/publisher/usecase.py index 23e8c5586e..46bb96ef2a 100644 --- a/faststream/broker/publisher/usecase.py +++ b/faststream/broker/publisher/usecase.py @@ -19,7 +19,12 @@ from faststream.asyncapi.message import get_response_schema from faststream.asyncapi.utils import to_camelcase from faststream.broker.publisher.proto import PublisherProto -from faststream.broker.types import MsgType, P_HandlerParams, T_HandlerReturn +from faststream.broker.types import ( + BrokerMiddleware, + MsgType, + P_HandlerParams, + T_HandlerReturn, +) from faststream.broker.wrapper.call import HandlerCallWrapper if TYPE_CHECKING: @@ -87,6 +92,9 @@ def __init__( self.include_in_schema = include_in_schema self.schema_ = schema_ + def add_middleware(self, middleware: "BrokerMiddleware[MsgType]") -> None: + self._broker_middlewares = (*self._broker_middlewares, middleware) + @override def setup( # type: ignore[override] self, diff --git a/faststream/broker/subscriber/proto.py b/faststream/broker/subscriber/proto.py index 41667126a4..47bd42b44d 100644 --- a/faststream/broker/subscriber/proto.py +++ b/faststream/broker/subscriber/proto.py @@ -31,10 +31,13 @@ class SubscriberProto( calls: List["HandlerItem[MsgType]"] running: bool - _broker_dependecies: Iterable["Depends"] + _broker_dependencies: Iterable["Depends"] _broker_middlewares: Iterable["BrokerMiddleware[MsgType]"] _producer: Optional["ProducerProto"] + @abstractmethod + def add_middleware(self, middleware: "BrokerMiddleware[MsgType]") -> None: ... + @abstractmethod def get_log_context( self, diff --git a/faststream/broker/subscriber/usecase.py b/faststream/broker/subscriber/usecase.py index 82e6ebce8c..a2e9d1aa58 100644 --- a/faststream/broker/subscriber/usecase.py +++ b/faststream/broker/subscriber/usecase.py @@ -86,7 +86,7 @@ class SubscriberUsecase( extra_context: "AnyDict" graceful_timeout: Optional[float] - _broker_dependecies: Iterable["Depends"] + _broker_dependencies: Iterable["Depends"] _call_options: Optional["_CallOptions"] def __init__( @@ -117,7 +117,7 @@ def __init__( self.lock = sync_fake_context() # Setup in include - self._broker_dependecies = broker_dependencies + self._broker_dependencies = broker_dependencies self._broker_middlewares = broker_middlewares # register in setup later @@ -131,6 +131,9 @@ def __init__( self.description_ = description_ self.include_in_schema = include_in_schema + def add_middleware(self, middleware: "BrokerMiddleware[MsgType]") -> None: + self._broker_middlewares = (*self._broker_middlewares, middleware) + @override def setup( # type: ignore[override] self, @@ -138,7 +141,7 @@ def setup( # type: ignore[override] logger: Optional["LoggerProto"], producer: Optional[ProducerProto], graceful_timeout: Optional[float], - extra_context: Optional["AnyDict"], + extra_context: "AnyDict", # broker options broker_parser: Optional["CustomCallable"], broker_decoder: Optional["CustomCallable"], @@ -152,7 +155,7 @@ def setup( # type: ignore[override] self._producer = producer self.graceful_timeout = graceful_timeout - self.extra_context = extra_context or {} + self.extra_context = extra_context self.watcher = get_watcher_context(logger, self._no_ack, self._retry) @@ -178,7 +181,7 @@ def setup( # type: ignore[override] is_validate=is_validate, _get_dependant=_get_dependant, _call_decorators=_call_decorators, - broker_dependencies=self._broker_dependecies, + broker_dependencies=self._broker_dependencies, ) call.handler.refresh(with_mock=False) diff --git a/faststream/broker/utils.py b/faststream/broker/utils.py index 8ca0585a4c..6903f4c94d 100644 --- a/faststream/broker/utils.py +++ b/faststream/broker/utils.py @@ -50,6 +50,7 @@ def get_watcher_context( return partial( WatcherContext, watcher=get_watcher(logger, retry), + logger=logger, **extra_options, ) diff --git a/faststream/broker/wrapper/call.py b/faststream/broker/wrapper/call.py index f991e1b749..2dda3bf1ea 100644 --- a/faststream/broker/wrapper/call.py +++ b/faststream/broker/wrapper/call.py @@ -1,5 +1,4 @@ import asyncio -from functools import wraps from typing import ( TYPE_CHECKING, Any, @@ -186,13 +185,9 @@ def set_wrapped( def _wrap_decode_message( func: Callable[..., Awaitable[T_HandlerReturn]], params_ln: int, -) -> Callable[ - ["StreamMessage[MsgType]"], - Awaitable[T_HandlerReturn], -]: +) -> Callable[["StreamMessage[MsgType]"], Awaitable[T_HandlerReturn]]: """Wraps a function to decode a message and pass it as an argument to the wrapped function.""" - @wraps(func) async def decode_wrapper(message: "StreamMessage[MsgType]") -> T_HandlerReturn: """A wrapper function to decode and handle a message.""" msg = message.decoded_body diff --git a/faststream/cli/docs/app.py b/faststream/cli/docs/app.py index 4751222a17..c8066a8b9e 100644 --- a/faststream/cli/docs/app.py +++ b/faststream/cli/docs/app.py @@ -44,6 +44,9 @@ def serve( " Defaults to the current working directory." ), ), + is_factory: bool = typer.Option( + False, "--factory", help="Treat APP as an application factory" + ), ) -> None: """Serve project AsyncAPI schema.""" if ":" in app: @@ -66,18 +69,18 @@ def serve( except ImportError: warnings.warn(INSTALL_WATCHFILES, category=ImportWarning, stacklevel=1) - _parse_and_serve(app, host, port) + _parse_and_serve(app, host, port, is_factory) else: WatchReloader( target=_parse_and_serve, - args=(app, host, port), + args=(app, host, port, is_factory), reload_dirs=(str(module_parent),), extra_extensions=extra_extensions, ).run() else: - _parse_and_serve(app, host, port) + _parse_and_serve(app, host, port, is_factory) @docs_app.command(name="gen") @@ -104,12 +107,19 @@ def gen( " Defaults to the current working directory." ), ), + is_factory: bool = typer.Option( + False, + "--factory", + help="Treat APP as an application factory", + ), ) -> None: """Generate project AsyncAPI schema.""" if app_dir: # pragma: no branch sys.path.insert(0, app_dir) _, app_obj = import_from_string(app) + if callable(app_obj) and is_factory: + app_obj = app_obj() raw_schema = get_app_schema(app_obj) if yaml: @@ -138,9 +148,12 @@ def _parse_and_serve( app: str, host: str = "localhost", port: int = 8000, + is_factory: bool = False, ) -> None: if ":" in app: _, app_obj = import_from_string(app) + if callable(app_obj) and is_factory: + app_obj = app_obj() raw_schema = get_app_schema(app_obj) else: diff --git a/faststream/cli/main.py b/faststream/cli/main.py index 7c0ec0391f..bbbe99aa33 100644 --- a/faststream/cli/main.py +++ b/faststream/cli/main.py @@ -94,6 +94,12 @@ def run( " Defaults to the current working directory." ), ), + is_factory: bool = typer.Option( + False, + "--factory", + is_flag=True, + help="Treat APP as an application factory", + ), ) -> None: """Run [MODULE:APP] FastStream application.""" if watch_extensions and not reload: @@ -108,7 +114,7 @@ def run( if app_dir: # pragma: no branch sys.path.insert(0, app_dir) - args = (app, extra, casted_log_level) + args = (app, extra, is_factory, casted_log_level) if reload and workers > 1: raise SetupError("You can't use reload option with multiprocessing") @@ -151,11 +157,14 @@ def _run( # NOTE: we should pass `str` due FastStream is not picklable app: str, extra_options: Dict[str, "SettingField"], + is_factory: bool, log_level: int = logging.INFO, app_level: int = logging.INFO, ) -> None: """Runs the specified application.""" _, app_obj = import_from_string(app) + if is_factory and callable(app_obj): + app_obj = app_obj() if not isinstance(app_obj, FastStream): raise typer.BadParameter( @@ -200,6 +209,12 @@ def publish( app: str = typer.Argument(..., help="FastStream app instance, e.g., main:app"), message: str = typer.Argument(..., help="Message to be published"), rpc: bool = typer.Option(False, help="Enable RPC mode and system output"), + is_factory: bool = typer.Option( + False, + "--factory", + is_flag=True, + help="Treat APP as an application factory", + ), ) -> None: """Publish a message using the specified broker in a FastStream application. @@ -218,6 +233,9 @@ def publish( raise ValueError("Message parameter is required.") _, app_obj = import_from_string(app) + if callable(app_obj) and is_factory: + app_obj = app_obj() + if not app_obj.broker: raise ValueError("Broker instance not found in the app.") diff --git a/faststream/confluent/broker/broker.py b/faststream/confluent/broker/broker.py index a462bb2257..30c97ae298 100644 --- a/faststream/confluent/broker/broker.py +++ b/faststream/confluent/broker/broker.py @@ -23,7 +23,11 @@ from faststream.broker.message import gen_cor_id from faststream.confluent.broker.logging import KafkaLoggingBroker from faststream.confluent.broker.registrator import KafkaRegistrator -from faststream.confluent.client import AsyncConfluentProducer, _missing +from faststream.confluent.client import ( + AsyncConfluentConsumer, + AsyncConfluentProducer, + _missing, +) from faststream.confluent.publisher.producer import AsyncConfluentFastProducer from faststream.confluent.schemas.params import ConsumerConnectionParams from faststream.confluent.security import parse_security @@ -425,7 +429,7 @@ async def connect( Doc("Kafka addresses to connect."), ] = Parameter.empty, **kwargs: Any, - ) -> ConsumerConnectionParams: + ) -> Callable[..., AsyncConfluentConsumer]: if bootstrap_servers is not Parameter.empty: kwargs["bootstrap_servers"] = bootstrap_servers @@ -437,17 +441,23 @@ async def _connect( # type: ignore[override] *, client_id: str, **kwargs: Any, - ) -> ConsumerConnectionParams: + ) -> Callable[..., AsyncConfluentConsumer]: security_params = parse_security(self.security) + kwargs.update(security_params) + producer = AsyncConfluentProducer( **kwargs, - **security_params, client_id=client_id, ) + self._producer = AsyncConfluentFastProducer( producer=producer, ) - return filter_by_dict(ConsumerConnectionParams, {**kwargs, **security_params}) + + return partial( + AsyncConfluentConsumer, + **filter_by_dict(ConsumerConnectionParams, kwargs), + ) async def start(self) -> None: await super().start() @@ -461,7 +471,11 @@ async def start(self) -> None: @property def _subscriber_setup_extra(self) -> "AnyDict": - return {"client_id": self.client_id, "connection_data": self._connection or {}} + return { + **super()._subscriber_setup_extra, + "client_id": self.client_id, + "builder": self._connection, + } @override async def publish( # type: ignore[override] diff --git a/faststream/confluent/broker/logging.py b/faststream/confluent/broker/logging.py index 9eebc89461..4fead65305 100644 --- a/faststream/confluent/broker/logging.py +++ b/faststream/confluent/broker/logging.py @@ -1,9 +1,9 @@ import logging from inspect import Parameter -from typing import TYPE_CHECKING, Any, ClassVar, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Tuple, Union from faststream.broker.core.usecase import BrokerUsecase -from faststream.confluent.schemas.params import ConsumerConnectionParams +from faststream.confluent.client import AsyncConfluentConsumer from faststream.log.logging import get_broker_logger if TYPE_CHECKING: @@ -15,7 +15,7 @@ class KafkaLoggingBroker( BrokerUsecase[ Union["confluent_kafka.Message", Tuple["confluent_kafka.Message", ...]], - ConsumerConnectionParams, + Callable[..., AsyncConfluentConsumer], ] ): """A class that extends the LoggingMixin class and adds additional functionality for logging Kafka related information.""" diff --git a/faststream/confluent/broker/registrator.py b/faststream/confluent/broker/registrator.py index 25c8a3128e..6d71a21046 100644 --- a/faststream/confluent/broker/registrator.py +++ b/faststream/confluent/broker/registrator.py @@ -1,4 +1,3 @@ -from functools import partial from typing import ( TYPE_CHECKING, Any, @@ -18,7 +17,6 @@ from faststream.broker.core.abc import ABCBroker from faststream.broker.utils import default_filter -from faststream.confluent.client import AsyncConfluentConsumer from faststream.confluent.publisher.asyncapi import AsyncAPIPublisher from faststream.confluent.subscriber.factory import create_subscriber from faststream.exceptions import SetupError @@ -1235,29 +1233,6 @@ def subscriber( if not auto_commit and not group_id: raise SetupError("You should install `group_id` with manual commit mode") - builder = partial( - AsyncConfluentConsumer, - key_deserializer=key_deserializer, - value_deserializer=value_deserializer, - fetch_max_wait_ms=fetch_max_wait_ms, - fetch_max_bytes=fetch_max_bytes, - fetch_min_bytes=fetch_min_bytes, - max_partition_fetch_bytes=max_partition_fetch_bytes, - auto_offset_reset=auto_offset_reset, - enable_auto_commit=auto_commit, - auto_commit_interval_ms=auto_commit_interval_ms, - check_crcs=check_crcs, - partition_assignment_strategy=partition_assignment_strategy, - max_poll_interval_ms=max_poll_interval_ms, - rebalance_timeout_ms=rebalance_timeout_ms, - session_timeout_ms=session_timeout_ms, - heartbeat_interval_ms=heartbeat_interval_ms, - consumer_timeout_ms=consumer_timeout_ms, - max_poll_records=max_poll_records, - exclude_internal_topics=exclude_internal_topics, - isolation_level=isolation_level, - ) - subscriber = super().subscriber( create_subscriber( *topics, @@ -1265,7 +1240,27 @@ def subscriber( batch_timeout_ms=batch_timeout_ms, max_records=max_records, group_id=group_id, - builder=builder, + connection_data={ + "key_deserializer": key_deserializer, + "value_deserializer": value_deserializer, + "fetch_max_wait_ms": fetch_max_wait_ms, + "fetch_max_bytes": fetch_max_bytes, + "fetch_min_bytes": fetch_min_bytes, + "max_partition_fetch_bytes": max_partition_fetch_bytes, + "auto_offset_reset": auto_offset_reset, + "enable_auto_commit": auto_commit, + "auto_commit_interval_ms": auto_commit_interval_ms, + "check_crcs": check_crcs, + "partition_assignment_strategy": partition_assignment_strategy, + "max_poll_interval_ms": max_poll_interval_ms, + "rebalance_timeout_ms": rebalance_timeout_ms, + "session_timeout_ms": session_timeout_ms, + "heartbeat_interval_ms": heartbeat_interval_ms, + "consumer_timeout_ms": consumer_timeout_ms, + "max_poll_records": max_poll_records, + "exclude_internal_topics": exclude_internal_topics, + "isolation_level": isolation_level, + }, is_manual=not auto_commit, # subscriber args no_ack=no_ack, diff --git a/faststream/confluent/opentelemetry/__init__.py b/faststream/confluent/opentelemetry/__init__.py new file mode 100644 index 0000000000..eb3bbafc74 --- /dev/null +++ b/faststream/confluent/opentelemetry/__init__.py @@ -0,0 +1,3 @@ +from faststream.confluent.opentelemetry.middleware import KafkaTelemetryMiddleware + +__all__ = ("KafkaTelemetryMiddleware",) diff --git a/faststream/confluent/opentelemetry/middleware.py b/faststream/confluent/opentelemetry/middleware.py new file mode 100644 index 0000000000..d8e5906dd3 --- /dev/null +++ b/faststream/confluent/opentelemetry/middleware.py @@ -0,0 +1,26 @@ +from typing import Optional + +from opentelemetry.metrics import Meter, MeterProvider +from opentelemetry.trace import TracerProvider + +from faststream.confluent.opentelemetry.provider import ( + telemetry_attributes_provider_factory, +) +from faststream.opentelemetry.middleware import TelemetryMiddleware + + +class KafkaTelemetryMiddleware(TelemetryMiddleware): + def __init__( + self, + *, + tracer_provider: Optional[TracerProvider] = None, + meter_provider: Optional[MeterProvider] = None, + meter: Optional[Meter] = None, + ) -> None: + super().__init__( + settings_provider_factory=telemetry_attributes_provider_factory, + tracer_provider=tracer_provider, + meter_provider=meter_provider, + meter=meter, + include_messages_counters=True, + ) diff --git a/faststream/confluent/opentelemetry/provider.py b/faststream/confluent/opentelemetry/provider.py new file mode 100644 index 0000000000..6add7330ca --- /dev/null +++ b/faststream/confluent/opentelemetry/provider.py @@ -0,0 +1,114 @@ +from typing import TYPE_CHECKING, Sequence, Tuple, Union, cast + +from opentelemetry.semconv.trace import SpanAttributes + +from faststream.broker.types import MsgType +from faststream.opentelemetry import TelemetrySettingsProvider +from faststream.opentelemetry.consts import MESSAGING_DESTINATION_PUBLISH_NAME + +if TYPE_CHECKING: + from confluent_kafka import Message + + from faststream.broker.message import StreamMessage + from faststream.types import AnyDict + + +class BaseConfluentTelemetrySettingsProvider(TelemetrySettingsProvider[MsgType]): + __slots__ = ("messaging_system",) + + def __init__(self) -> None: + self.messaging_system = "kafka" + + def get_publish_attrs_from_kwargs( + self, + kwargs: "AnyDict", + ) -> "AnyDict": + attrs = { + SpanAttributes.MESSAGING_SYSTEM: self.messaging_system, + SpanAttributes.MESSAGING_DESTINATION_NAME: kwargs["topic"], + SpanAttributes.MESSAGING_MESSAGE_CONVERSATION_ID: kwargs["correlation_id"], + } + + if (partition := kwargs.get("partition")) is not None: + attrs[SpanAttributes.MESSAGING_KAFKA_DESTINATION_PARTITION] = partition + + if (key := kwargs.get("key")) is not None: + attrs[SpanAttributes.MESSAGING_KAFKA_MESSAGE_KEY] = key + + return attrs + + @staticmethod + def get_publish_destination_name( + kwargs: "AnyDict", + ) -> str: + return cast(str, kwargs["topic"]) + + +class ConfluentTelemetrySettingsProvider( + BaseConfluentTelemetrySettingsProvider["Message"] +): + def get_consume_attrs_from_message( + self, + msg: "StreamMessage[Message]", + ) -> "AnyDict": + attrs = { + SpanAttributes.MESSAGING_SYSTEM: self.messaging_system, + SpanAttributes.MESSAGING_MESSAGE_ID: msg.message_id, + SpanAttributes.MESSAGING_MESSAGE_CONVERSATION_ID: msg.correlation_id, + SpanAttributes.MESSAGING_MESSAGE_PAYLOAD_SIZE_BYTES: len(msg.body), + SpanAttributes.MESSAGING_KAFKA_DESTINATION_PARTITION: msg.raw_message.partition(), + SpanAttributes.MESSAGING_KAFKA_MESSAGE_OFFSET: msg.raw_message.offset(), + MESSAGING_DESTINATION_PUBLISH_NAME: msg.raw_message.topic(), + } + + if (key := msg.raw_message.key()) is not None: + attrs[SpanAttributes.MESSAGING_KAFKA_MESSAGE_KEY] = key + + return attrs + + @staticmethod + def get_consume_destination_name( + msg: "StreamMessage[Message]", + ) -> str: + return cast(str, msg.raw_message.topic()) + + +class BatchConfluentTelemetrySettingsProvider( + BaseConfluentTelemetrySettingsProvider[Tuple["Message", ...]] +): + def get_consume_attrs_from_message( + self, + msg: "StreamMessage[Tuple[Message, ...]]", + ) -> "AnyDict": + raw_message = msg.raw_message[0] + attrs = { + SpanAttributes.MESSAGING_SYSTEM: self.messaging_system, + SpanAttributes.MESSAGING_MESSAGE_ID: msg.message_id, + SpanAttributes.MESSAGING_MESSAGE_CONVERSATION_ID: msg.correlation_id, + SpanAttributes.MESSAGING_BATCH_MESSAGE_COUNT: len(msg.raw_message), + SpanAttributes.MESSAGING_MESSAGE_PAYLOAD_SIZE_BYTES: len( + bytearray().join(cast(Sequence[bytes], msg.body)) + ), + SpanAttributes.MESSAGING_KAFKA_DESTINATION_PARTITION: raw_message.partition(), + MESSAGING_DESTINATION_PUBLISH_NAME: raw_message.topic(), + } + + return attrs + + @staticmethod + def get_consume_destination_name( + msg: "StreamMessage[Tuple[Message, ...]]", + ) -> str: + return cast(str, msg.raw_message[0].topic()) + + +def telemetry_attributes_provider_factory( + msg: Union["Message", Sequence["Message"], None], +) -> Union[ + ConfluentTelemetrySettingsProvider, + BatchConfluentTelemetrySettingsProvider, +]: + if isinstance(msg, Sequence): + return BatchConfluentTelemetrySettingsProvider() + else: + return ConfluentTelemetrySettingsProvider() diff --git a/faststream/confluent/parser.py b/faststream/confluent/parser.py index a4858247ac..e743c96e6b 100644 --- a/faststream/confluent/parser.py +++ b/faststream/confluent/parser.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union from faststream.broker.message import decode_message, gen_cor_id from faststream.confluent.message import FAKE_CONSUMER, KafkaMessage @@ -20,18 +20,14 @@ async def parse_message( message: "Message", ) -> "StreamMessage[Message]": """Parses a Kafka message.""" - headers = {} - if message.headers() is not None: - for i, j in message.headers(): # type: ignore[union-attr] - if isinstance(j, str): - headers[i] = j - else: - headers[i] = j.decode() + headers = _parse_msg_headers(message.headers()) + body = message.value() offset = message.offset() _, timestamp = message.timestamp() handler: Optional["LogicSubscriber[Any]"] = context.get_local("handler_") + return KafkaMessage( body=body, headers=headers, @@ -49,28 +45,29 @@ async def parse_message_batch( message: Tuple["Message", ...], ) -> "StreamMessage[Tuple[Message, ...]]": """Parses a batch of messages from a Kafka consumer.""" + body: List[Any] = [] + batch_headers: List[Dict[str, str]] = [] + first = message[0] last = message[-1] - headers = {} - if first.headers() is not None: - for i, j in first.headers(): # type: ignore[union-attr] - if isinstance(j, str): - headers[i] = j - else: - headers[i] = j.decode() - body = [m.value() for m in message] - first_offset = first.offset() - last_offset = last.offset() + for m in message: + body.append(m.value()) + batch_headers.append(_parse_msg_headers(m.headers())) + + headers = next(iter(batch_headers), {}) + _, first_timestamp = first.timestamp() handler: Optional["LogicSubscriber[Any]"] = context.get_local("handler_") + return KafkaMessage( body=body, headers=headers, + batch_headers=batch_headers, reply_to=headers.get("reply_to", ""), content_type=headers.get("content-type"), - message_id=f"{first_offset}-{last_offset}-{first_timestamp}", + message_id=f"{first.offset()}-{last.offset()}-{first_timestamp}", correlation_id=headers.get("correlation_id", gen_cor_id()), raw_message=message, consumer=getattr(handler, "consumer", None) or FAKE_CONSUMER, @@ -91,3 +88,9 @@ async def decode_message_batch( ) -> "DecodedMessage": """Decode a batch of messages.""" return [decode_message(await cls.parse_message(m)) for m in msg.raw_message] + + +def _parse_msg_headers( + headers: Sequence[Tuple[str, Union[bytes, str]]], +) -> Dict[str, str]: + return {i: j if isinstance(j, str) else j.decode() for i, j in headers} diff --git a/faststream/confluent/router.py b/faststream/confluent/router.py index 98830127ef..33480a12ea 100644 --- a/faststream/confluent/router.py +++ b/faststream/confluent/router.py @@ -1,6 +1,7 @@ from typing import ( TYPE_CHECKING, Any, + Awaitable, Callable, Dict, Iterable, @@ -129,7 +130,10 @@ class KafkaRoute(SubscriberRoute): def __init__( self, call: Annotated[ - Callable[..., "SendableMessage"], + Union[ + Callable[..., "SendableMessage"], + Callable[..., Awaitable["SendableMessage"]], + ], Doc("Message handler function."), ], *topics: Annotated[ @@ -468,13 +472,13 @@ def __init__( class KafkaRouter( + KafkaRegistrator, BrokerRouter[ Union[ "Message", Tuple["Message", ...], ] ], - KafkaRegistrator, ): """Includable to KafkaBroker router.""" diff --git a/faststream/confluent/subscriber/asyncapi.py b/faststream/confluent/subscriber/asyncapi.py index 7ec3ffb965..8eacb40ee8 100644 --- a/faststream/confluent/subscriber/asyncapi.py +++ b/faststream/confluent/subscriber/asyncapi.py @@ -22,6 +22,9 @@ if TYPE_CHECKING: from confluent_kafka import Message as ConfluentMsg + from fast_depends.dependencies import Depends + from faststream.broker.types import BrokerMiddleware + from faststream.types import AnyDict class AsyncAPISubscriber(LogicSubscriber[MsgType]): @@ -56,7 +59,6 @@ def get_schema(self) -> Dict[str, Channel]: return channels - class AsyncAPIDefaultSubscriber( DefaultSubscriber, AsyncAPISubscriber["ConfluentMsg"], diff --git a/faststream/confluent/subscriber/usecase.py b/faststream/confluent/subscriber/usecase.py index d778086bae..28f7ece4e7 100644 --- a/faststream/confluent/subscriber/usecase.py +++ b/faststream/confluent/subscriber/usecase.py @@ -19,7 +19,6 @@ from faststream.broker.subscriber.usecase import SubscriberUsecase from faststream.broker.types import MsgType from faststream.confluent.parser import AsyncConfluentParser -from faststream.confluent.schemas.params import ConsumerConnectionParams if TYPE_CHECKING: from fast_depends.dependencies import Depends @@ -41,7 +40,9 @@ class LogicSubscriber(ABC, SubscriberUsecase[MsgType]): topics: Sequence[str] group_id: Optional[str] + builder: Optional[Callable[..., "AsyncConfluentConsumer"]] consumer: Optional["AsyncConfluentConsumer"] + task: Optional["asyncio.Task[None]"] client_id: Optional[str] @@ -50,7 +51,7 @@ def __init__( *topics: str, # Kafka information group_id: Optional[str], - builder: Callable[..., "AsyncConfluentConsumer"], + connection_data: "AnyDict", is_manual: bool, # Subscriber args default_parser: "AsyncCallable", @@ -81,25 +82,25 @@ def __init__( self.group_id = group_id self.topics = topics self.is_manual = is_manual - self.builder = builder + self.builder = None self.consumer = None self.task = None # Setup it later self.client_id = "" - self.__connection_data = ConsumerConnectionParams() + self.__connection_data = connection_data @override def setup( # type: ignore[override] self, *, client_id: Optional[str], - connection_data: "ConsumerConnectionParams", + builder: Callable[..., "AsyncConfluentConsumer"], # basic args logger: Optional["LoggerProto"], producer: Optional["ProducerProto"], graceful_timeout: Optional[float], - extra_context: Optional["AnyDict"], + extra_context: "AnyDict", # broker options broker_parser: Optional["CustomCallable"], broker_decoder: Optional["CustomCallable"], @@ -110,7 +111,7 @@ def setup( # type: ignore[override] _call_decorators: Iterable["Decorator"], ) -> None: self.client_id = client_id - self.__connection_data = connection_data + self.builder = builder super().setup( logger=logger, @@ -128,6 +129,8 @@ def setup( # type: ignore[override] @override async def start(self) -> None: """Start the consumer.""" + assert self.builder, "You should setup subscriber at first." # nosec B101 + self.consumer = consumer = self.builder( *self.topics, group_id=self.group_id, @@ -172,7 +175,7 @@ async def get_msg(self) -> Optional[MsgType]: raise NotImplementedError() async def _consume(self) -> None: - assert self.consumer, "You need to start handler first" # nosec B101 + assert self.consumer, "You should start subscriber at first." # nosec B101 connected = True while self.running: @@ -219,7 +222,7 @@ def __init__( *topics: str, # Kafka information group_id: Optional[str], - builder: Callable[..., "AsyncConfluentConsumer"], + connection_data: "AnyDict", is_manual: bool, # Subscriber args no_ack: bool, @@ -234,7 +237,7 @@ def __init__( super().__init__( *topics, group_id=group_id, - builder=builder, + connection_data=connection_data, is_manual=is_manual, # subscriber args default_parser=AsyncConfluentParser.parse_message, @@ -278,7 +281,7 @@ def __init__( max_records: Optional[int], # Kafka information group_id: Optional[str], - builder: Callable[..., "AsyncConfluentConsumer"], + connection_data: "AnyDict", is_manual: bool, # Subscriber args no_ack: bool, @@ -296,7 +299,7 @@ def __init__( super().__init__( *topics, group_id=group_id, - builder=builder, + connection_data=connection_data, is_manual=is_manual, # subscriber args default_parser=AsyncConfluentParser.parse_message_batch, diff --git a/faststream/confluent/testing.py b/faststream/confluent/testing.py index 4559cbde8b..9420ff3aa5 100644 --- a/faststream/confluent/testing.py +++ b/faststream/confluent/testing.py @@ -1,5 +1,6 @@ from datetime import datetime -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple +from unittest.mock import AsyncMock, MagicMock from typing_extensions import override @@ -22,8 +23,13 @@ class TestKafkaBroker(TestBroker[KafkaBroker]): """A class to test Kafka brokers.""" @staticmethod - async def _fake_connect(broker: KafkaBroker, *args: Any, **kwargs: Any) -> None: + async def _fake_connect( # type: ignore[override] + broker: KafkaBroker, + *args: Any, + **kwargs: Any, + ) -> Callable[..., AsyncMock]: broker._producer = FakeProducer(broker) + return _fake_connection @staticmethod def create_publisher_fake_subscriber( @@ -231,3 +237,10 @@ def build_message( timestamp_type=0 + 1, timestamp_ms=timestamp_ms or int(datetime.now().timestamp()), ) + + +def _fake_connection(*args: Any, **kwargs: Any) -> AsyncMock: + mock = AsyncMock() + mock.getone.return_value = MagicMock() + mock.getmany.return_value = [MagicMock()] + return mock diff --git a/faststream/kafka/__init__.py b/faststream/kafka/__init__.py index eb83bd8b01..c81b617033 100644 --- a/faststream/kafka/__init__.py +++ b/faststream/kafka/__init__.py @@ -1,3 +1,5 @@ +from aiokafka import TopicPartition + from faststream.kafka.annotations import KafkaMessage from faststream.kafka.broker import KafkaBroker from faststream.kafka.router import KafkaPublisher, KafkaRoute, KafkaRouter @@ -12,4 +14,5 @@ "KafkaPublisher", "TestKafkaBroker", "TestApp", + "TopicPartition", ) diff --git a/faststream/kafka/broker/broker.py b/faststream/kafka/broker/broker.py index 16df9c7c8c..de0b6980f1 100644 --- a/faststream/kafka/broker/broker.py +++ b/faststream/kafka/broker/broker.py @@ -534,6 +534,7 @@ def __init__( apply_types=apply_types, validate=validate, ) + self.client_id = client_id self._producer = None @@ -557,7 +558,7 @@ async def connect( # type: ignore[override] Doc("Kafka addresses to connect."), ] = Parameter.empty, **kwargs: "Unpack[KafkaInitKwargs]", - ) -> ConsumerConnectionParams: + ) -> Callable[..., aiokafka.AIOKafkaConsumer]: """Connect to Kafka servers manually. Consumes the same with `KafkaBroker.__init__` arguments and overrides them. @@ -579,18 +580,24 @@ async def _connect( # type: ignore[override] *, client_id: str, **kwargs: Any, - ) -> ConsumerConnectionParams: + ) -> Callable[..., aiokafka.AIOKafkaConsumer]: security_params = parse_security(self.security) + kwargs.update(security_params) + producer = aiokafka.AIOKafkaProducer( **kwargs, - **security_params, client_id=client_id, ) + await producer.start() self._producer = AioKafkaFastProducer( producer=producer, ) - return filter_by_dict(ConsumerConnectionParams, {**kwargs, **security_params}) + + return partial( + aiokafka.AIOKafkaConsumer, + **filter_by_dict(ConsumerConnectionParams, kwargs), + ) async def start(self) -> None: """Connect broker to Kafka and startup all subscribers.""" @@ -606,8 +613,9 @@ async def start(self) -> None: @property def _subscriber_setup_extra(self) -> "AnyDict": return { + **super()._subscriber_setup_extra, "client_id": self.client_id, - "connection_args": self._connection or {}, + "builder": self._connection, } @override diff --git a/faststream/kafka/broker/logging.py b/faststream/kafka/broker/logging.py index df828024da..16b1103b83 100644 --- a/faststream/kafka/broker/logging.py +++ b/faststream/kafka/broker/logging.py @@ -1,9 +1,8 @@ import logging from inspect import Parameter -from typing import TYPE_CHECKING, Any, ClassVar, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Tuple, Union from faststream.broker.core.usecase import BrokerUsecase -from faststream.kafka.schemas.params import ConsumerConnectionParams from faststream.log.logging import get_broker_logger if TYPE_CHECKING: @@ -15,7 +14,7 @@ class KafkaLoggingBroker( BrokerUsecase[ Union["aiokafka.ConsumerRecord", Tuple["aiokafka.ConsumerRecord", ...]], - ConsumerConnectionParams, + Callable[..., "aiokafka.AIOKafkaConsumer"], ] ): """A class that extends the LoggingMixin class and adds additional functionality for logging Kafka related information.""" diff --git a/faststream/kafka/broker/registrator.py b/faststream/kafka/broker/registrator.py index c574e46614..0633032c06 100644 --- a/faststream/kafka/broker/registrator.py +++ b/faststream/kafka/broker/registrator.py @@ -1,4 +1,3 @@ -from functools import partial from typing import ( TYPE_CHECKING, Any, @@ -14,18 +13,17 @@ overload, ) -from aiokafka import AIOKafkaConsumer, ConsumerRecord +from aiokafka import ConsumerRecord from aiokafka.coordinator.assignors.roundrobin import RoundRobinPartitionAssignor from typing_extensions import Annotated, Doc, deprecated, override from faststream.broker.core.abc import ABCBroker from faststream.broker.utils import default_filter -from faststream.exceptions import SetupError from faststream.kafka.publisher.asyncapi import AsyncAPIPublisher from faststream.kafka.subscriber.factory import create_subscriber if TYPE_CHECKING: - from aiokafka import ConsumerRecord + from aiokafka import ConsumerRecord, TopicPartition from aiokafka.abc import ConsumerRebalanceListener from aiokafka.coordinator.assignors.abstract import AbstractPartitionAssignor from fast_depends.dependencies import Depends @@ -336,6 +334,13 @@ def subscriber( Pattern to match available topics. You must provide either topics or pattern, but not both. """), ] = None, + partitions: Annotated[ + Iterable["TopicPartition"], + Doc(""" + An explicit partitions list to assign. + You can't use 'topics' and 'partitions' in the same time. + """), + ] = (), # broker args dependencies: Annotated[ Iterable["Depends"], @@ -660,6 +665,13 @@ def subscriber( Pattern to match available topics. You must provide either topics or pattern, but not both. """), ] = None, + partitions: Annotated[ + Iterable["TopicPartition"], + Doc(""" + An explicit partitions list to assign. + You can't use 'topics' and 'partitions' in the same time. + """), + ] = (), # broker args dependencies: Annotated[ Iterable["Depends"], @@ -984,6 +996,13 @@ def subscriber( Pattern to match available topics. You must provide either topics or pattern, but not both. """), ] = None, + partitions: Annotated[ + Iterable["TopicPartition"], + Doc(""" + An explicit partitions list to assign. + You can't use 'topics' and 'partitions' in the same time. + """), + ] = (), # broker args dependencies: Annotated[ Iterable["Depends"], @@ -1311,6 +1330,13 @@ def subscriber( Pattern to match available topics. You must provide either topics or pattern, but not both. """), ] = None, + partitions: Annotated[ + Iterable["TopicPartition"], + Doc(""" + An explicit partitions list to assign. + You can't use 'topics' and 'partitions' in the same time. + """), + ] = (), # broker args dependencies: Annotated[ Iterable["Depends"], @@ -1367,32 +1393,6 @@ def subscriber( "AsyncAPIDefaultSubscriber", "AsyncAPIBatchSubscriber", ]: - if not auto_commit and not group_id: - raise SetupError("You should install `group_id` with manual commit mode") - - builder = partial( - AIOKafkaConsumer, - key_deserializer=key_deserializer, - value_deserializer=value_deserializer, - fetch_max_wait_ms=fetch_max_wait_ms, - fetch_max_bytes=fetch_max_bytes, - fetch_min_bytes=fetch_min_bytes, - max_partition_fetch_bytes=max_partition_fetch_bytes, - auto_offset_reset=auto_offset_reset, - enable_auto_commit=auto_commit, - auto_commit_interval_ms=auto_commit_interval_ms, - check_crcs=check_crcs, - partition_assignment_strategy=partition_assignment_strategy, - max_poll_interval_ms=max_poll_interval_ms, - rebalance_timeout_ms=rebalance_timeout_ms, - session_timeout_ms=session_timeout_ms, - heartbeat_interval_ms=heartbeat_interval_ms, - consumer_timeout_ms=consumer_timeout_ms, - max_poll_records=max_poll_records, - exclude_internal_topics=exclude_internal_topics, - isolation_level=isolation_level, - ) - subscriber = super().subscriber( create_subscriber( *topics, @@ -1402,7 +1402,28 @@ def subscriber( group_id=group_id, listener=listener, pattern=pattern, - builder=builder, + connection_args={ + "key_deserializer": key_deserializer, + "value_deserializer": value_deserializer, + "fetch_max_wait_ms": fetch_max_wait_ms, + "fetch_max_bytes": fetch_max_bytes, + "fetch_min_bytes": fetch_min_bytes, + "max_partition_fetch_bytes": max_partition_fetch_bytes, + "auto_offset_reset": auto_offset_reset, + "enable_auto_commit": auto_commit, + "auto_commit_interval_ms": auto_commit_interval_ms, + "check_crcs": check_crcs, + "partition_assignment_strategy": partition_assignment_strategy, + "max_poll_interval_ms": max_poll_interval_ms, + "rebalance_timeout_ms": rebalance_timeout_ms, + "session_timeout_ms": session_timeout_ms, + "heartbeat_interval_ms": heartbeat_interval_ms, + "consumer_timeout_ms": consumer_timeout_ms, + "max_poll_records": max_poll_records, + "exclude_internal_topics": exclude_internal_topics, + "isolation_level": isolation_level, + }, + partitions=partitions, is_manual=not auto_commit, # subscriber args no_ack=no_ack, diff --git a/faststream/kafka/fastapi/fastapi.py b/faststream/kafka/fastapi/fastapi.py index ce988aa329..541940d79e 100644 --- a/faststream/kafka/fastapi/fastapi.py +++ b/faststream/kafka/fastapi/fastapi.py @@ -38,6 +38,7 @@ from asyncio import AbstractEventLoop from enum import Enum + from aiokafka import TopicPartition from aiokafka.abc import AbstractTokenProvider, ConsumerRebalanceListener from aiokafka.coordinator.assignors.abstract import AbstractPartitionAssignor from fastapi import params @@ -919,6 +920,13 @@ def subscriber( Pattern to match available topics. You must provide either topics or pattern, but not both. """), ] = None, + partitions: Annotated[ + Iterable["TopicPartition"], + Doc(""" + An explicit partitions list to assign. + You can't use 'topics' and 'partitions' in the same time. + """), + ] = (), # broker args dependencies: Annotated[ Iterable["params.Depends"], @@ -1401,6 +1409,13 @@ def subscriber( Pattern to match available topics. You must provide either topics or pattern, but not both. """), ] = None, + partitions: Annotated[ + Iterable["TopicPartition"], + Doc(""" + An explicit partitions list to assign. + You can't use 'topics' and 'partitions' in the same time. + """), + ] = (), # broker args dependencies: Annotated[ Iterable["params.Depends"], @@ -1883,6 +1898,13 @@ def subscriber( Pattern to match available topics. You must provide either topics or pattern, but not both. """), ] = None, + partitions: Annotated[ + Iterable["TopicPartition"], + Doc(""" + An explicit partitions list to assign. + You can't use 'topics' and 'partitions' in the same time. + """), + ] = (), # broker args dependencies: Annotated[ Iterable["params.Depends"], @@ -2368,6 +2390,13 @@ def subscriber( Pattern to match available topics. You must provide either topics or pattern, but not both. """), ] = None, + partitions: Annotated[ + Iterable["TopicPartition"], + Doc(""" + An explicit partitions list to assign. + You can't use 'topics' and 'partitions' in the same time. + """), + ] = (), # broker args dependencies: Annotated[ Iterable["params.Depends"], @@ -2575,6 +2604,7 @@ def subscriber( batch_timeout_ms=batch_timeout_ms, listener=listener, pattern=pattern, + partitions=partitions, # broker args dependencies=dependencies, parser=parser, diff --git a/faststream/kafka/opentelemetry/__init__.py b/faststream/kafka/opentelemetry/__init__.py new file mode 100644 index 0000000000..6bd75f272c --- /dev/null +++ b/faststream/kafka/opentelemetry/__init__.py @@ -0,0 +1,3 @@ +from faststream.kafka.opentelemetry.middleware import KafkaTelemetryMiddleware + +__all__ = ("KafkaTelemetryMiddleware",) diff --git a/faststream/kafka/opentelemetry/middleware.py b/faststream/kafka/opentelemetry/middleware.py new file mode 100644 index 0000000000..2f06486c33 --- /dev/null +++ b/faststream/kafka/opentelemetry/middleware.py @@ -0,0 +1,26 @@ +from typing import Optional + +from opentelemetry.metrics import Meter, MeterProvider +from opentelemetry.trace import TracerProvider + +from faststream.kafka.opentelemetry.provider import ( + telemetry_attributes_provider_factory, +) +from faststream.opentelemetry.middleware import TelemetryMiddleware + + +class KafkaTelemetryMiddleware(TelemetryMiddleware): + def __init__( + self, + *, + tracer_provider: Optional[TracerProvider] = None, + meter_provider: Optional[MeterProvider] = None, + meter: Optional[Meter] = None, + ) -> None: + super().__init__( + settings_provider_factory=telemetry_attributes_provider_factory, + tracer_provider=tracer_provider, + meter_provider=meter_provider, + meter=meter, + include_messages_counters=True, + ) diff --git a/faststream/kafka/opentelemetry/provider.py b/faststream/kafka/opentelemetry/provider.py new file mode 100644 index 0000000000..b1702b6022 --- /dev/null +++ b/faststream/kafka/opentelemetry/provider.py @@ -0,0 +1,115 @@ +from typing import TYPE_CHECKING, Sequence, Tuple, Union, cast + +from opentelemetry.semconv.trace import SpanAttributes + +from faststream.broker.types import MsgType +from faststream.opentelemetry import TelemetrySettingsProvider +from faststream.opentelemetry.consts import MESSAGING_DESTINATION_PUBLISH_NAME + +if TYPE_CHECKING: + from aiokafka import ConsumerRecord + + from faststream.broker.message import StreamMessage + from faststream.types import AnyDict + + +class BaseKafkaTelemetrySettingsProvider(TelemetrySettingsProvider[MsgType]): + __slots__ = ("messaging_system",) + + def __init__(self) -> None: + self.messaging_system = "kafka" + + def get_publish_attrs_from_kwargs( + self, + kwargs: "AnyDict", + ) -> "AnyDict": + attrs = { + SpanAttributes.MESSAGING_SYSTEM: self.messaging_system, + SpanAttributes.MESSAGING_DESTINATION_NAME: kwargs["topic"], + SpanAttributes.MESSAGING_MESSAGE_CONVERSATION_ID: kwargs["correlation_id"], + } + + if (partition := kwargs.get("partition")) is not None: + attrs[SpanAttributes.MESSAGING_KAFKA_DESTINATION_PARTITION] = partition + + if (key := kwargs.get("key")) is not None: + attrs[SpanAttributes.MESSAGING_KAFKA_MESSAGE_KEY] = key + + return attrs + + @staticmethod + def get_publish_destination_name( + kwargs: "AnyDict", + ) -> str: + return cast(str, kwargs["topic"]) + + +class KafkaTelemetrySettingsProvider( + BaseKafkaTelemetrySettingsProvider["ConsumerRecord"] +): + def get_consume_attrs_from_message( + self, + msg: "StreamMessage[ConsumerRecord]", + ) -> "AnyDict": + attrs = { + SpanAttributes.MESSAGING_SYSTEM: self.messaging_system, + SpanAttributes.MESSAGING_MESSAGE_ID: msg.message_id, + SpanAttributes.MESSAGING_MESSAGE_CONVERSATION_ID: msg.correlation_id, + SpanAttributes.MESSAGING_MESSAGE_PAYLOAD_SIZE_BYTES: len(msg.body), + SpanAttributes.MESSAGING_KAFKA_DESTINATION_PARTITION: msg.raw_message.partition, + SpanAttributes.MESSAGING_KAFKA_MESSAGE_OFFSET: msg.raw_message.offset, + MESSAGING_DESTINATION_PUBLISH_NAME: msg.raw_message.topic, + } + + if msg.raw_message.key is not None: + attrs[SpanAttributes.MESSAGING_KAFKA_MESSAGE_KEY] = msg.raw_message.key + + return attrs + + @staticmethod + def get_consume_destination_name( + msg: "StreamMessage[ConsumerRecord]", + ) -> str: + return cast(str, msg.raw_message.topic) + + +class BatchKafkaTelemetrySettingsProvider( + BaseKafkaTelemetrySettingsProvider[Tuple["ConsumerRecord", ...]] +): + def get_consume_attrs_from_message( + self, + msg: "StreamMessage[Tuple[ConsumerRecord, ...]]", + ) -> "AnyDict": + raw_message = msg.raw_message[0] + + attrs = { + SpanAttributes.MESSAGING_SYSTEM: self.messaging_system, + SpanAttributes.MESSAGING_MESSAGE_ID: msg.message_id, + SpanAttributes.MESSAGING_MESSAGE_CONVERSATION_ID: msg.correlation_id, + SpanAttributes.MESSAGING_MESSAGE_PAYLOAD_SIZE_BYTES: len( + bytearray().join(cast(Sequence[bytes], msg.body)) + ), + SpanAttributes.MESSAGING_BATCH_MESSAGE_COUNT: len(msg.raw_message), + SpanAttributes.MESSAGING_KAFKA_DESTINATION_PARTITION: raw_message.partition, + MESSAGING_DESTINATION_PUBLISH_NAME: raw_message.topic, + } + + return attrs + + @staticmethod + def get_consume_destination_name( + msg: "StreamMessage[Tuple[ConsumerRecord, ...]]", + ) -> str: + return cast(str, msg.raw_message[0].topic) + + +def telemetry_attributes_provider_factory( + msg: Union["ConsumerRecord", Sequence["ConsumerRecord"], None], +) -> Union[ + KafkaTelemetrySettingsProvider, + BatchKafkaTelemetrySettingsProvider, +]: + if isinstance(msg, Sequence): + return BatchKafkaTelemetrySettingsProvider() + else: + return KafkaTelemetrySettingsProvider() diff --git a/faststream/kafka/parser.py b/faststream/kafka/parser.py index c99bc31c33..8487eb3d0b 100644 --- a/faststream/kafka/parser.py +++ b/faststream/kafka/parser.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from faststream.broker.message import decode_message, gen_cor_id from faststream.kafka.message import FAKE_CONSUMER, KafkaMessage @@ -39,13 +39,24 @@ async def parse_message_batch( message: Tuple["ConsumerRecord", ...], ) -> "StreamMessage[Tuple[ConsumerRecord, ...]]": """Parses a batch of messages from a Kafka consumer.""" + body: List[Any] = [] + batch_headers: List[Dict[str, str]] = [] + first = message[0] last = message[-1] - headers = {i: j.decode() for i, j in first.headers} + + for m in message: + body.append(m.value) + batch_headers.append({i: j.decode() for i, j in m.headers}) + + headers = next(iter(batch_headers), {}) + handler: Optional["LogicSubscriber[Any]"] = context.get_local("handler_") + return KafkaMessage( - body=[m.value for m in message], + body=body, headers=headers, + batch_headers=batch_headers, reply_to=headers.get("reply_to", ""), content_type=headers.get("content-type"), message_id=f"{first.offset}-{last.offset}-{first.timestamp}", diff --git a/faststream/kafka/router.py b/faststream/kafka/router.py index 9640bebf65..44540ee4d5 100644 --- a/faststream/kafka/router.py +++ b/faststream/kafka/router.py @@ -1,6 +1,7 @@ from typing import ( TYPE_CHECKING, Any, + Awaitable, Callable, Dict, Iterable, @@ -19,7 +20,8 @@ from faststream.kafka.broker.registrator import KafkaRegistrator if TYPE_CHECKING: - from aiokafka import ConsumerRecord + from aiokafka import ConsumerRecord, TopicPartition + from aiokafka.abc import ConsumerRebalanceListener from aiokafka.coordinator.assignors.abstract import AbstractPartitionAssignor from fast_depends.dependencies import Depends @@ -131,7 +133,10 @@ class KafkaRoute(SubscriberRoute): def __init__( self, call: Annotated[ - Callable[..., "SendableMessage"], + Union[ + Callable[..., "SendableMessage"], + Callable[..., Awaitable["SendableMessage"]], + ], Doc( "Message handler function " "to wrap the same with `@broker.subscriber(...)` way." @@ -376,6 +381,44 @@ def __init__( Optional[int], Doc("Number of messages to consume as one batch."), ] = None, + listener: Annotated[ + Optional["ConsumerRebalanceListener"], + Doc(""" + Optionally include listener + callback, which will be called before and after each rebalance + operation. + As part of group management, the consumer will keep track of + the list of consumers that belong to a particular group and + will trigger a rebalance operation if one of the following + events trigger: + + * Number of partitions change for any of the subscribed topics + * Topic is created or deleted + * An existing member of the consumer group dies + * A new member is added to the consumer group + + When any of these events are triggered, the provided listener + will be invoked first to indicate that the consumer's + assignment has been revoked, and then again when the new + assignment has been received. Note that this listener will + immediately override any listener set in a previous call + to subscribe. It is guaranteed, however, that the partitions + revoked/assigned + through this interface are from topics subscribed in this call. + """), + ] = None, + pattern: Annotated[ + Optional[str], + Doc(""" + Pattern to match available topics. You must provide either topics or pattern, but not both. + """), + ] = None, + partitions: Annotated[ + Optional[Iterable["TopicPartition"]], + Doc(""" + A topic and partition tuple. You can't use 'topics' and 'partitions' in the same time. + """), + ] = (), # broker args dependencies: Annotated[ Iterable["Depends"], @@ -456,6 +499,9 @@ def __init__( max_records=max_records, batch_timeout_ms=batch_timeout_ms, batch=batch, + listener=listener, + pattern=pattern, + partitions=partitions, # basic args dependencies=dependencies, parser=parser, @@ -473,13 +519,13 @@ def __init__( class KafkaRouter( + KafkaRegistrator, BrokerRouter[ Union[ "ConsumerRecord", Tuple["ConsumerRecord", ...], ] ], - KafkaRegistrator, ): """Includable to KafkaBroker router.""" diff --git a/faststream/kafka/subscriber/asyncapi.py b/faststream/kafka/subscriber/asyncapi.py index 9adb8dad3c..cef8f9e11c 100644 --- a/faststream/kafka/subscriber/asyncapi.py +++ b/faststream/kafka/subscriber/asyncapi.py @@ -14,6 +14,7 @@ from faststream.asyncapi.schema.bindings import kafka from faststream.asyncapi.utils import resolve_payloads from faststream.broker.types import MsgType +from faststream.exceptions import SetupError from faststream.kafka.subscriber.usecase import ( BatchSubscriber, DefaultSubscriber, @@ -22,6 +23,12 @@ if TYPE_CHECKING: from aiokafka import ConsumerRecord + from aiokafka import ConsumerRecord, TopicPartition + from aiokafka.abc import ConsumerRebalanceListener + from fast_depends.dependencies import Depends + + from faststream.broker.types import BrokerMiddleware + from faststream.types import AnyDict class AsyncAPISubscriber(LogicSubscriber[MsgType]): @@ -56,7 +63,6 @@ def get_schema(self) -> Dict[str, Channel]: return channels - class AsyncAPIDefaultSubscriber( DefaultSubscriber, AsyncAPISubscriber["ConsumerRecord"], diff --git a/faststream/kafka/subscriber/usecase.py b/faststream/kafka/subscriber/usecase.py index 0a99702b98..650bae75d1 100644 --- a/faststream/kafka/subscriber/usecase.py +++ b/faststream/kafka/subscriber/usecase.py @@ -7,12 +7,14 @@ Callable, Dict, Iterable, + List, Optional, Sequence, Tuple, ) import anyio +from aiokafka import TopicPartition from aiokafka.errors import ConsumerStoppedError, KafkaError from typing_extensions import override @@ -33,7 +35,6 @@ from faststream.broker.message import StreamMessage from faststream.broker.publisher.proto import ProducerProto - from faststream.kafka.schemas.params import ConsumerConnectionParams from faststream.types import AnyDict, Decorator, LoggerProto @@ -43,7 +44,9 @@ class LogicSubscriber(ABC, SubscriberUsecase[MsgType]): topics: Sequence[str] group_id: Optional[str] + builder: Optional[Callable[..., "AIOKafkaConsumer"]] consumer: Optional["AIOKafkaConsumer"] + task: Optional["asyncio.Task[None]"] client_id: Optional[str] batch: bool @@ -53,9 +56,10 @@ def __init__( *topics: str, # Kafka information group_id: Optional[str], - builder: Callable[..., "AIOKafkaConsumer"], + connection_args: "AnyDict", listener: Optional["ConsumerRebalanceListener"], pattern: Optional[str], + partitions: Iterable["TopicPartition"], is_manual: bool, # Subscriber args default_parser: "AsyncCallable", @@ -83,10 +87,12 @@ def __init__( include_in_schema=include_in_schema, ) - self.group_id = group_id self.topics = topics + self.partitions = partitions + self.group_id = group_id + self.is_manual = is_manual - self.builder = builder + self.builder = None self.consumer = None self.task = None @@ -94,19 +100,19 @@ def __init__( self.client_id = "" self.__pattern = pattern self.__listener = listener - self.__connection_args: "ConsumerConnectionParams" = {} + self.__connection_args = connection_args @override def setup( # type: ignore[override] self, *, client_id: Optional[str], - connection_args: "ConsumerConnectionParams", + builder: Callable[..., "AIOKafkaConsumer"], # basic args logger: Optional["LoggerProto"], producer: Optional["ProducerProto"], graceful_timeout: Optional[float], - extra_context: Optional["AnyDict"], + extra_context: "AnyDict", # broker options broker_parser: Optional["CustomCallable"], broker_decoder: Optional["CustomCallable"], @@ -117,7 +123,7 @@ def setup( # type: ignore[override] _call_decorators: Iterable["Decorator"], ) -> None: self.client_id = client_id - self.__connection_args = connection_args + self.builder = builder super().setup( logger=logger, @@ -134,18 +140,25 @@ def setup( # type: ignore[override] async def start(self) -> None: """Start the consumer.""" + assert self.builder, "You should setup subscriber at first." # nosec B101 + self.consumer = consumer = self.builder( group_id=self.group_id, client_id=self.client_id, **self.__connection_args, ) - consumer.subscribe( - topics=self.topics, - pattern=self.__pattern, - listener=self.__listener, - ) - await consumer.start() + if self.topics: + consumer.subscribe( + topics=self.topics, + pattern=self.__pattern, + listener=self.__listener, + ) + + elif self.partitions: + consumer.assign(partitions=self.partitions) + + await consumer.start() await super().start() self.task = asyncio.create_task(self._consume()) @@ -183,7 +196,7 @@ async def get_msg(self) -> MsgType: raise NotImplementedError() async def _consume(self) -> None: - assert self.consumer, "You should setup subscriber at first." # nosec B101 + assert self.consumer, "You should start subscriber at first." # nosec B101 connected = True while self.running: @@ -213,9 +226,18 @@ def get_routing_hash( ) -> int: return hash("".join((*topics, group_id or ""))) + @property + def topic_names(self) -> List[str]: + if self.__pattern: + return [self.__pattern] + elif self.topics: + return list(self.topics) + else: + return [f"{p.topic}-{p.partition}" for p in self.partitions] + def __hash__(self) -> int: return self.get_routing_hash( - topics=(*self.topics, self.__pattern or ""), + topics=self.topic_names, group_id=self.group_id, ) @@ -236,7 +258,7 @@ def get_log_context( message: Optional["StreamMessage[ConsumerRecord]"], ) -> Dict[str, str]: if message is None: - topic = ",".join(self.topics) + topic = ",".join(self.topic_names) elif isinstance(message.raw_message, Sequence): topic = message.raw_message[0].topic else: @@ -251,6 +273,14 @@ def get_log_context( def add_prefix(self, prefix: str) -> None: self.topics = tuple("".join((prefix, t)) for t in self.topics) + self.partitions = [ + TopicPartition( + topic="".join((prefix, p.topic)), + partition=p.partition, + ) + for p in self.partitions + ] + class DefaultSubscriber(LogicSubscriber["ConsumerRecord"]): def __init__( @@ -260,7 +290,8 @@ def __init__( group_id: Optional[str], listener: Optional["ConsumerRebalanceListener"], pattern: Optional[str], - builder: Callable[..., "AIOKafkaConsumer"], + connection_args: "AnyDict", + partitions: Iterable["TopicPartition"], is_manual: bool, # Subscriber args no_ack: bool, @@ -277,7 +308,8 @@ def __init__( group_id=group_id, listener=listener, pattern=pattern, - builder=builder, + connection_args=connection_args, + partitions=partitions, is_manual=is_manual, # subscriber args default_parser=AioKafkaParser.parse_message, @@ -308,7 +340,8 @@ def __init__( group_id: Optional[str], listener: Optional["ConsumerRebalanceListener"], pattern: Optional[str], - builder: Callable[..., "AIOKafkaConsumer"], + connection_args: "AnyDict", + partitions: Iterable["TopicPartition"], is_manual: bool, # Subscriber args no_ack: bool, @@ -330,7 +363,8 @@ def __init__( group_id=group_id, listener=listener, pattern=pattern, - builder=builder, + connection_args=connection_args, + partitions=partitions, is_manual=is_manual, # subscriber args default_parser=AioKafkaParser.parse_message_batch, diff --git a/faststream/kafka/testing.py b/faststream/kafka/testing.py old mode 100644 new mode 100755 index fb9e71417f..fd8b520332 --- a/faststream/kafka/testing.py +++ b/faststream/kafka/testing.py @@ -1,10 +1,12 @@ from datetime import datetime -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional +from unittest.mock import AsyncMock, MagicMock from aiokafka import ConsumerRecord from typing_extensions import override from faststream.broker.message import encode_message, gen_cor_id +from faststream.kafka import TopicPartition from faststream.kafka.broker import KafkaBroker from faststream.kafka.publisher.asyncapi import AsyncAPIBatchPublisher from faststream.kafka.publisher.producer import AioKafkaFastProducer @@ -23,18 +25,30 @@ class TestKafkaBroker(TestBroker[KafkaBroker]): """A class to test Kafka brokers.""" @staticmethod - async def _fake_connect(broker: KafkaBroker, *args: Any, **kwargs: Any) -> None: + async def _fake_connect( # type: ignore[override] + broker: KafkaBroker, + *args: Any, + **kwargs: Any, + ) -> Callable[..., AsyncMock]: broker._producer = FakeProducer(broker) + return _fake_connection @staticmethod def create_publisher_fake_subscriber( broker: KafkaBroker, publisher: "AsyncAPIPublisher[Any]", ) -> "HandlerCallWrapper[Any, Any, Any]": - sub = broker.subscriber( - publisher.topic, - batch=isinstance(publisher, AsyncAPIBatchPublisher), - ) + if publisher.partition: + tp = TopicPartition(topic=publisher.topic, partition=publisher.partition) + sub = broker.subscriber( + partitions=[tp], + batch=isinstance(publisher, AsyncAPIBatchPublisher), + ) + else: + sub = broker.subscriber( + publisher.topic, + batch=isinstance(publisher, AsyncAPIBatchPublisher), + ) if not sub.calls: @@ -92,7 +106,16 @@ async def publish( # type: ignore[override] ) for handler in self.broker._subscribers.values(): # pragma: no branch - if topic in handler.topics: + call: bool = False + + for p in handler.partitions: + if p.topic == topic and (partition is None or p.partition == partition): + call = True + + if not call and topic in handler.topics: + call = True + + if call: return await call_handler( handler=handler, message=[incoming] @@ -184,3 +207,10 @@ def build_message( offset=0, headers=[(i, j.encode()) for i, j in headers.items()], ) + + +def _fake_connection(*args: Any, **kwargs: Any) -> AsyncMock: + mock = AsyncMock() + mock.subscribe = MagicMock + mock.assign = MagicMock + return mock diff --git a/faststream/nats/broker/broker.py b/faststream/nats/broker/broker.py index a1b7c4f0b9..da4da2b394 100644 --- a/faststream/nats/broker/broker.py +++ b/faststream/nats/broker/broker.py @@ -655,6 +655,28 @@ async def start(self) -> None: else: # pragma: no cover self._log(str(e), logging.ERROR, log_context, exc_info=e) + except BadRequestError as e: + if ( + e.description + == "stream name already in use with a different configuration" + ): + old_config = (await self.stream.stream_info(stream.name)).config + + self._log(str(e), logging.WARNING, log_context) + await self.stream.update_stream( + config=stream.config, + subjects=tuple( + set(old_config.subjects or ()).union(stream.subjects) + ), + ) + + else: # pragma: no cover + self._log(str(e), logging.ERROR, log_context, exc_info=e) + + finally: + # prevent from double declaration + stream.declare = False + # TODO: filter by already running handlers after TestClient refactor for handler in self._subscribers.values(): self._log( @@ -730,7 +752,7 @@ async def publish( # type: ignore[override] Please, use `@broker.publisher(...)` or `broker.publisher(...).publish(...)` instead in a regular way. """ - publihs_kwargs = { + publish_kwargs = { "subject": subject, "headers": headers, "reply_to": reply_to, @@ -745,7 +767,7 @@ async def publish( # type: ignore[override] producer = self._producer else: producer = self._js_producer - publihs_kwargs.update( + publish_kwargs.update( { "stream": stream, "timeout": timeout, @@ -755,7 +777,7 @@ async def publish( # type: ignore[override] return await super().publish( message, producer=producer, - **publihs_kwargs, + **publish_kwargs, ) @override @@ -802,10 +824,7 @@ def setup_publisher( # type: ignore[override] elif self._producer is not None: producer = self._producer - publisher.setup( - producer=producer, - **self._publisher_setup_extra, - ) + super().setup_publisher(publisher, producer=producer) async def key_value( self, diff --git a/faststream/nats/opentelemetry/__init__.py b/faststream/nats/opentelemetry/__init__.py new file mode 100644 index 0000000000..d97f2b5d38 --- /dev/null +++ b/faststream/nats/opentelemetry/__init__.py @@ -0,0 +1,3 @@ +from faststream.nats.opentelemetry.middleware import NatsTelemetryMiddleware + +__all__ = ("NatsTelemetryMiddleware",) diff --git a/faststream/nats/opentelemetry/middleware.py b/faststream/nats/opentelemetry/middleware.py new file mode 100644 index 0000000000..cafd8787d8 --- /dev/null +++ b/faststream/nats/opentelemetry/middleware.py @@ -0,0 +1,24 @@ +from typing import Optional + +from opentelemetry.metrics import Meter, MeterProvider +from opentelemetry.trace import TracerProvider + +from faststream.nats.opentelemetry.provider import telemetry_attributes_provider_factory +from faststream.opentelemetry.middleware import TelemetryMiddleware + + +class NatsTelemetryMiddleware(TelemetryMiddleware): + def __init__( + self, + *, + tracer_provider: Optional[TracerProvider] = None, + meter_provider: Optional[MeterProvider] = None, + meter: Optional[Meter] = None, + ) -> None: + super().__init__( + settings_provider_factory=telemetry_attributes_provider_factory, + tracer_provider=tracer_provider, + meter_provider=meter_provider, + meter=meter, + include_messages_counters=True, + ) diff --git a/faststream/nats/opentelemetry/provider.py b/faststream/nats/opentelemetry/provider.py new file mode 100644 index 0000000000..7aefafed2c --- /dev/null +++ b/faststream/nats/opentelemetry/provider.py @@ -0,0 +1,114 @@ +from typing import TYPE_CHECKING, List, Optional, Sequence, Union, overload + +from opentelemetry.semconv.trace import SpanAttributes + +from faststream.__about__ import SERVICE_NAME +from faststream.broker.types import MsgType +from faststream.opentelemetry import TelemetrySettingsProvider +from faststream.opentelemetry.consts import MESSAGING_DESTINATION_PUBLISH_NAME + +if TYPE_CHECKING: + from nats.aio.msg import Msg + + from faststream.broker.message import StreamMessage + from faststream.types import AnyDict + + +class BaseNatsTelemetrySettingsProvider(TelemetrySettingsProvider[MsgType]): + __slots__ = ("messaging_system",) + + def __init__(self) -> None: + self.messaging_system = "nats" + + def get_publish_attrs_from_kwargs( + self, + kwargs: "AnyDict", + ) -> "AnyDict": + return { + SpanAttributes.MESSAGING_SYSTEM: self.messaging_system, + SpanAttributes.MESSAGING_DESTINATION_NAME: kwargs["subject"], + SpanAttributes.MESSAGING_MESSAGE_CONVERSATION_ID: kwargs["correlation_id"], + } + + @staticmethod + def get_publish_destination_name( + kwargs: "AnyDict", + ) -> str: + subject: str = kwargs.get("subject", SERVICE_NAME) + return subject + + +class NatsTelemetrySettingsProvider(BaseNatsTelemetrySettingsProvider["Msg"]): + def get_consume_attrs_from_message( + self, + msg: "StreamMessage[Msg]", + ) -> "AnyDict": + return { + SpanAttributes.MESSAGING_SYSTEM: self.messaging_system, + SpanAttributes.MESSAGING_MESSAGE_ID: msg.message_id, + SpanAttributes.MESSAGING_MESSAGE_CONVERSATION_ID: msg.correlation_id, + SpanAttributes.MESSAGING_MESSAGE_PAYLOAD_SIZE_BYTES: len(msg.body), + MESSAGING_DESTINATION_PUBLISH_NAME: msg.raw_message.subject, + } + + @staticmethod + def get_consume_destination_name( + msg: "StreamMessage[Msg]", + ) -> str: + return msg.raw_message.subject + + +class NatsBatchTelemetrySettingsProvider( + BaseNatsTelemetrySettingsProvider[List["Msg"]] +): + def get_consume_attrs_from_message( + self, + msg: "StreamMessage[List[Msg]]", + ) -> "AnyDict": + return { + SpanAttributes.MESSAGING_SYSTEM: self.messaging_system, + SpanAttributes.MESSAGING_MESSAGE_ID: msg.message_id, + SpanAttributes.MESSAGING_MESSAGE_CONVERSATION_ID: msg.correlation_id, + SpanAttributes.MESSAGING_MESSAGE_PAYLOAD_SIZE_BYTES: len(msg.body), + SpanAttributes.MESSAGING_BATCH_MESSAGE_COUNT: len(msg.raw_message), + MESSAGING_DESTINATION_PUBLISH_NAME: msg.raw_message[0].subject, + } + + @staticmethod + def get_consume_destination_name( + msg: "StreamMessage[List[Msg]]", + ) -> str: + return msg.raw_message[0].subject + + +@overload +def telemetry_attributes_provider_factory( + msg: Optional["Msg"], +) -> NatsTelemetrySettingsProvider: ... + + +@overload +def telemetry_attributes_provider_factory( + msg: Sequence["Msg"], +) -> NatsBatchTelemetrySettingsProvider: ... + + +@overload +def telemetry_attributes_provider_factory( + msg: Union["Msg", Sequence["Msg"], None], +) -> Union[ + NatsTelemetrySettingsProvider, + NatsBatchTelemetrySettingsProvider, +]: ... + + +def telemetry_attributes_provider_factory( + msg: Union["Msg", Sequence["Msg"], None], +) -> Union[ + NatsTelemetrySettingsProvider, + NatsBatchTelemetrySettingsProvider, +]: + if isinstance(msg, Sequence): + return NatsBatchTelemetrySettingsProvider() + else: + return NatsTelemetrySettingsProvider() diff --git a/faststream/nats/parser.py b/faststream/nats/parser.py index 4fea13ee3c..4824d84716 100644 --- a/faststream/nats/parser.py +++ b/faststream/nats/parser.py @@ -1,4 +1,5 @@ from typing import TYPE_CHECKING, Any, List, Optional +from typing import TYPE_CHECKING, Dict, List, Optional from faststream.broker.message import StreamMessage, decode_message, gen_cor_id from faststream.nats.message import ( @@ -109,15 +110,27 @@ async def parse_batch( self, message: List["Msg"], ) -> "StreamMessage[List[Msg]]": - if first_msg := next(iter(message), None): - path = self.get_path(first_msg.subject) + body: List[bytes] = [] + batch_headers: List[Dict[str, str]] = [] + + if message: + path = self.get_path(message[0].subject) + + for m in message: + batch_headers.append(m.headers or {}) + body.append(m.data) + else: path = None + headers = next(iter(batch_headers), {}) + return NatsBatchMessage( raw_message=message, - body=[m.data for m in message], + body=body, path=path or {}, + headers=headers, + batch_headers=batch_headers, ) async def decode_batch( diff --git a/faststream/nats/router.py b/faststream/nats/router.py index 3f29af056d..679010773a 100644 --- a/faststream/nats/router.py +++ b/faststream/nats/router.py @@ -1,4 +1,13 @@ -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Optional, Union +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + Iterable, + Optional, + Union, +) from nats.js import api from typing_extensions import Annotated, Doc, deprecated @@ -106,7 +115,10 @@ class NatsRoute(SubscriberRoute): def __init__( self, call: Annotated[ - Callable[..., "SendableMessage"], + Union[ + Callable[..., "SendableMessage"], + Callable[..., Awaitable["SendableMessage"]], + ], Doc( "Message handler function " "to wrap the same with `@broker.subscriber(...)` way." @@ -312,8 +324,8 @@ def __init__( class NatsRouter( - BrokerRouter["Msg"], NatsRegistrator, + BrokerRouter["Msg"], ): """Includable to NatsBroker router.""" diff --git a/faststream/nats/subscriber/asyncapi.py b/faststream/nats/subscriber/asyncapi.py index ad0edb0bca..ce5cfdb6ee 100644 --- a/faststream/nats/subscriber/asyncapi.py +++ b/faststream/nats/subscriber/asyncapi.py @@ -103,10 +103,16 @@ def get_schema(self) -> Dict[str, Channel]: class AsyncAPIObjStoreWatchSubscriber(AsyncAPISubscriber, ObjStoreWatchSubscriber): """ObjStoreWatch consumer with AsyncAPI methods.""" +class AsyncAPIDefaultSubscriber(DefaultHandler, AsyncAPISubscriber): + """One-message consumer with AsyncAPI methods.""" + @override def get_name(self) -> str: return "" - @override def get_schema(self) -> Dict[str, Channel]: return {} + +class AsyncAPIBatchSubscriber(BatchHandler, AsyncAPISubscriber): + """Batch-message consumer with AsyncAPI methods.""" + diff --git a/faststream/nats/subscriber/usecase.py b/faststream/nats/subscriber/usecase.py index 25e88dc68d..7a0fbf9aa3 100644 --- a/faststream/nats/subscriber/usecase.py +++ b/faststream/nats/subscriber/usecase.py @@ -117,7 +117,7 @@ def setup( # type: ignore[override] logger: Optional["LoggerProto"], producer: Optional["ProducerProto"], graceful_timeout: Optional[float], - extra_context: Optional["AnyDict"], + extra_context: "AnyDict", # broker options broker_parser: Optional["CustomCallable"], broker_decoder: Optional["CustomCallable"], diff --git a/faststream/nats/testing.py b/faststream/nats/testing.py index 5bde0eb1f1..5a9190dfd7 100644 --- a/faststream/nats/testing.py +++ b/faststream/nats/testing.py @@ -1,9 +1,11 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from unittest.mock import AsyncMock from nats.aio.msg import Msg from typing_extensions import override from faststream.broker.message import encode_message, gen_cor_id +from faststream.exceptions import WRONG_PUBLISH_ARGS from faststream.nats.broker import NatsBroker from faststream.nats.publisher.producer import NatsFastProducer from faststream.nats.schemas.js_stream import is_subject_match_wildcard @@ -39,8 +41,14 @@ def f(msg: Any) -> None: return sub.calls[0].handler @staticmethod - async def _fake_connect(broker: NatsBroker, *args: Any, **kwargs: Any) -> None: + async def _fake_connect( # type: ignore[override] + broker: NatsBroker, + *args: Any, + **kwargs: Any, + ) -> AsyncMock: + broker.stream = AsyncMock() # type: ignore[assignment] broker._js_producer = broker._producer = FakeProducer(broker) # type: ignore[assignment] + return AsyncMock() @staticmethod def remove_publisher_fake_subscriber( @@ -71,6 +79,9 @@ async def publish( # type: ignore[override] rpc_timeout: Optional[float] = None, raise_timeout: bool = False, ) -> Any: + if rpc and reply_to: + raise WRONG_PUBLISH_ARGS + incoming = build_message( message=message, subject=subject, diff --git a/faststream/opentelemetry/__init__.py b/faststream/opentelemetry/__init__.py new file mode 100644 index 0000000000..401c1be077 --- /dev/null +++ b/faststream/opentelemetry/__init__.py @@ -0,0 +1,7 @@ +from faststream.opentelemetry.middleware import TelemetryMiddleware +from faststream.opentelemetry.provider import TelemetrySettingsProvider + +__all__ = ( + "TelemetryMiddleware", + "TelemetrySettingsProvider", +) diff --git a/faststream/opentelemetry/consts.py b/faststream/opentelemetry/consts.py new file mode 100644 index 0000000000..2436d568ee --- /dev/null +++ b/faststream/opentelemetry/consts.py @@ -0,0 +1,9 @@ +class MessageAction: + CREATE = "create" + PUBLISH = "publish" + PROCESS = "process" + RECEIVE = "receive" + + +ERROR_TYPE = "error.type" +MESSAGING_DESTINATION_PUBLISH_NAME = "messaging.destination_publish.name" diff --git a/faststream/opentelemetry/middleware.py b/faststream/opentelemetry/middleware.py new file mode 100644 index 0000000000..9a4ad34c10 --- /dev/null +++ b/faststream/opentelemetry/middleware.py @@ -0,0 +1,299 @@ +import time +from copy import copy +from typing import TYPE_CHECKING, Any, Callable, Optional, Type + +from opentelemetry import context, metrics, propagate, trace +from opentelemetry.semconv.trace import SpanAttributes + +from faststream import BaseMiddleware +from faststream.opentelemetry.consts import ( + ERROR_TYPE, + MESSAGING_DESTINATION_PUBLISH_NAME, + MessageAction, +) +from faststream.opentelemetry.provider import TelemetrySettingsProvider + +if TYPE_CHECKING: + from types import TracebackType + + from opentelemetry.context import Context + from opentelemetry.metrics import Meter, MeterProvider + from opentelemetry.trace import Span, Tracer, TracerProvider + + from faststream.broker.message import StreamMessage + from faststream.types import AnyDict, AsyncFunc, AsyncFuncAny + + +_OTEL_SCHEMA = "https://opentelemetry.io/schemas/1.11.0" + + +def _create_span_name(destination: str, action: str) -> str: + return f"{destination} {action}" + + +class _MetricsContainer: + __slots__ = ( + "include_messages_counters", + "publish_duration", + "publish_counter", + "process_duration", + "process_counter", + ) + + def __init__(self, meter: "Meter", include_messages_counters: bool) -> None: + self.include_messages_counters = include_messages_counters + + self.publish_duration = meter.create_histogram( + name="messaging.publish.duration", + unit="s", + description="Measures the duration of publish operation.", + ) + self.process_duration = meter.create_histogram( + name="messaging.process.duration", + unit="s", + description="Measures the duration of process operation.", + ) + + if include_messages_counters: + self.process_counter = meter.create_counter( + name="messaging.process.messages", + unit="message", + description="Measures the number of processed messages.", + ) + self.publish_counter = meter.create_counter( + name="messaging.publish.messages", + unit="message", + description="Measures the number of published messages.", + ) + + def observe_publish( + self, attrs: "AnyDict", duration: float, msg_count: int + ) -> None: + self.publish_duration.record( + amount=duration, + attributes=attrs, + ) + if self.include_messages_counters: + counter_attrs = copy(attrs) + counter_attrs.pop(ERROR_TYPE, None) + self.publish_counter.add( + amount=msg_count, + attributes=counter_attrs, + ) + + def observe_consume( + self, attrs: "AnyDict", duration: float, msg_count: int + ) -> None: + self.process_duration.record( + amount=duration, + attributes=attrs, + ) + if self.include_messages_counters: + counter_attrs = copy(attrs) + counter_attrs.pop(ERROR_TYPE, None) + self.process_counter.add( + amount=msg_count, + attributes=counter_attrs, + ) + + +class BaseTelemetryMiddleware(BaseMiddleware): + def __init__( + self, + *, + tracer: "Tracer", + settings_provider_factory: Callable[[Any], TelemetrySettingsProvider[Any]], + metrics_container: _MetricsContainer, + msg: Optional[Any] = None, + ) -> None: + self.msg = msg + + self._tracer = tracer + self._metrics = metrics_container + self._current_span: Optional[Span] = None + self._origin_context: Optional[Context] = None + self.__settings_provider = settings_provider_factory(msg) + + async def publish_scope( + self, + call_next: "AsyncFunc", + msg: Any, + *args: Any, + **kwargs: Any, + ) -> Any: + provider = self.__settings_provider + + headers = kwargs.pop("headers", {}) or {} + current_context = context.get_current() + destination_name = provider.get_publish_destination_name(kwargs) + + trace_attributes = provider.get_publish_attrs_from_kwargs(kwargs) + metrics_attributes = { + SpanAttributes.MESSAGING_SYSTEM: provider.messaging_system, + SpanAttributes.MESSAGING_DESTINATION_NAME: destination_name, + } + + # NOTE: if batch with single message? + if (msg_count := len((msg, *args))) > 1: + trace_attributes[SpanAttributes.MESSAGING_BATCH_MESSAGE_COUNT] = msg_count + + if self._current_span and self._current_span.is_recording(): + current_context = trace.set_span_in_context( + self._current_span, current_context + ) + propagate.inject(headers, context=self._origin_context) + + else: + create_span = self._tracer.start_span( + name=_create_span_name(destination_name, MessageAction.CREATE), + kind=trace.SpanKind.PRODUCER, + attributes=trace_attributes, + ) + current_context = trace.set_span_in_context(create_span) + propagate.inject(headers, context=current_context) + create_span.end() + + start_time = time.perf_counter() + + try: + with self._tracer.start_as_current_span( + name=_create_span_name(destination_name, MessageAction.PUBLISH), + kind=trace.SpanKind.PRODUCER, + attributes=trace_attributes, + context=current_context, + ) as span: + span.set_attribute( + SpanAttributes.MESSAGING_OPERATION, MessageAction.PUBLISH + ) + result = await call_next(msg, *args, headers=headers, **kwargs) + + except Exception as e: + metrics_attributes[ERROR_TYPE] = type(e).__name__ + raise + + finally: + duration = time.perf_counter() - start_time + self._metrics.observe_publish(metrics_attributes, duration, msg_count) + + return result + + async def consume_scope( + self, + call_next: "AsyncFuncAny", + msg: "StreamMessage[Any]", + ) -> Any: + provider = self.__settings_provider + + current_context = propagate.extract(msg.headers) + destination_name = provider.get_consume_destination_name(msg) + + trace_attributes = provider.get_consume_attrs_from_message(msg) + metrics_attributes = { + SpanAttributes.MESSAGING_SYSTEM: provider.messaging_system, + MESSAGING_DESTINATION_PUBLISH_NAME: destination_name, + } + + if not len(current_context): + create_span = self._tracer.start_span( + name=_create_span_name(destination_name, MessageAction.CREATE), + kind=trace.SpanKind.CONSUMER, + attributes=trace_attributes, + ) + current_context = trace.set_span_in_context(create_span) + create_span.end() + + self._origin_context = current_context + start_time = time.perf_counter() + + try: + with self._tracer.start_as_current_span( + name=_create_span_name(destination_name, MessageAction.PROCESS), + kind=trace.SpanKind.CONSUMER, + context=current_context, + attributes=trace_attributes, + end_on_exit=False, + ) as span: + span.set_attribute( + SpanAttributes.MESSAGING_OPERATION, MessageAction.PROCESS + ) + self._current_span = span + new_context = trace.set_span_in_context(span, current_context) + token = context.attach(new_context) + result = await call_next(msg) + context.detach(token) + + except Exception as e: + metrics_attributes[ERROR_TYPE] = type(e).__name__ + raise + + finally: + duration = time.perf_counter() - start_time + msg_count = trace_attributes.get( + SpanAttributes.MESSAGING_BATCH_MESSAGE_COUNT, 1 + ) + self._metrics.observe_consume(metrics_attributes, duration, msg_count) + + return result + + async def after_processed( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_val: Optional[BaseException] = None, + exc_tb: Optional["TracebackType"] = None, + ) -> Optional[bool]: + if self._current_span and self._current_span.is_recording(): + self._current_span.end() + return False + + +class TelemetryMiddleware: + # NOTE: should it be class or function? + __slots__ = ( + "_tracer", + "_meter", + "_metrics", + "_settings_provider_factory", + ) + + def __init__( + self, + *, + settings_provider_factory: Callable[[Any], TelemetrySettingsProvider[Any]], + tracer_provider: Optional["TracerProvider"] = None, + meter_provider: Optional["MeterProvider"] = None, + meter: Optional["Meter"] = None, + include_messages_counters: bool = False, + ) -> None: + self._tracer = _get_tracer(tracer_provider) + self._meter = _get_meter(meter_provider, meter) + self._metrics = _MetricsContainer(self._meter, include_messages_counters) + self._settings_provider_factory = settings_provider_factory + + def __call__(self, msg: Optional[Any]) -> BaseMiddleware: + return BaseTelemetryMiddleware( + tracer=self._tracer, + metrics_container=self._metrics, + settings_provider_factory=self._settings_provider_factory, + msg=msg, + ) + + +def _get_meter( + meter_provider: Optional["MeterProvider"] = None, + meter: Optional["Meter"] = None, +) -> "Meter": + if meter is None: + return metrics.get_meter( + __name__, + meter_provider=meter_provider, + schema_url=_OTEL_SCHEMA, + ) + return meter + + +def _get_tracer(tracer_provider: Optional["TracerProvider"] = None) -> "Tracer": + return trace.get_tracer( + __name__, + tracer_provider=tracer_provider, + schema_url=_OTEL_SCHEMA, + ) diff --git a/faststream/opentelemetry/provider.py b/faststream/opentelemetry/provider.py new file mode 100644 index 0000000000..90232d45ab --- /dev/null +++ b/faststream/opentelemetry/provider.py @@ -0,0 +1,31 @@ +from typing import TYPE_CHECKING, Protocol + +from faststream.broker.types import MsgType + +if TYPE_CHECKING: + from faststream.broker.message import StreamMessage + from faststream.types import AnyDict + + +class TelemetrySettingsProvider(Protocol[MsgType]): + messaging_system: str + + def get_consume_attrs_from_message( + self, + msg: "StreamMessage[MsgType]", + ) -> "AnyDict": ... + + def get_consume_destination_name( + self, + msg: "StreamMessage[MsgType]", + ) -> str: ... + + def get_publish_attrs_from_kwargs( + self, + kwargs: "AnyDict", + ) -> "AnyDict": ... + + def get_publish_destination_name( + self, + kwargs: "AnyDict", + ) -> str: ... diff --git a/faststream/rabbit/__init__.py b/faststream/rabbit/__init__.py index 11ca1a9373..7c05cb70c8 100644 --- a/faststream/rabbit/__init__.py +++ b/faststream/rabbit/__init__.py @@ -21,5 +21,6 @@ "ReplyConfig", "RabbitExchange", "RabbitQueue", + # Annotations "RabbitMessage", ) diff --git a/faststream/rabbit/annotations.py b/faststream/rabbit/annotations.py index bfb78c6af9..f32654d2cc 100644 --- a/faststream/rabbit/annotations.py +++ b/faststream/rabbit/annotations.py @@ -1,3 +1,4 @@ +from aio_pika import RobustChannel, RobustConnection from typing_extensions import Annotated from faststream.annotations import ContextRepo, Logger, NoCast @@ -13,8 +14,20 @@ "RabbitMessage", "RabbitBroker", "RabbitProducer", + "Channel", + "Connection", ) RabbitMessage = Annotated[RM, Context("message")] RabbitBroker = Annotated[RB, Context("broker")] RabbitProducer = Annotated[AioPikaFastProducer, Context("broker._producer")] + +Channel = Annotated[RobustChannel, Context("broker._channel")] +Connection = Annotated[RobustConnection, Context("broker._connection")] + +# NOTE: transaction is not for the public usage yet +# async def _get_transaction(connection: Connection) -> RabbitTransaction: +# async with connection.channel(publisher_confirms=False) as channel: +# yield channel.transaction() + +# Transaction = Annotated[RabbitTransaction, Depends(_get_transaction)] diff --git a/faststream/rabbit/broker/broker.py b/faststream/rabbit/broker/broker.py index fd4ca30d84..f7ec134f86 100644 --- a/faststream/rabbit/broker/broker.py +++ b/faststream/rabbit/broker/broker.py @@ -100,6 +100,26 @@ def __init__( "TimeoutType", Doc("Connection establishement timeout."), ] = None, + # channel args + channel_number: Annotated[ + Optional[int], + Doc("Specify the channel number explicit."), + ] = None, + publisher_confirms: Annotated[ + bool, + Doc( + "if `True` the `publish` method will " + "return `bool` type after publish is complete." + "Otherwise it will returns `None`." + ), + ] = True, + on_return_raises: Annotated[ + bool, + Doc( + "raise an :class:`aio_pika.exceptions.DeliveryError`" + "when mandatory message will be returned" + ), + ] = False, # broker args max_consumers: Annotated[ Optional[int], @@ -220,6 +240,10 @@ def __init__( url=str(amqp_url), ssl_context=security_args.get("ssl_context"), timeout=timeout, + # channel args + channel_number=channel_number, + publisher_confirms=publisher_confirms, + on_return_raises=on_return_raises, # Basic args graceful_timeout=graceful_timeout, dependencies=dependencies, @@ -254,6 +278,7 @@ def __init__( @property def _subscriber_setup_extra(self) -> "AnyDict": return { + **super()._subscriber_setup_extra, "app_id": self.app_id, "virtual_host": self.virtual_host, "declarer": self.declarer, @@ -262,6 +287,7 @@ def _subscriber_setup_extra(self) -> "AnyDict": @property def _publisher_setup_extra(self) -> "AnyDict": return { + **super()._publisher_setup_extra, "app_id": self.app_id, "virtual_host": self.virtual_host, } @@ -303,6 +329,26 @@ async def connect( # type: ignore[override] "TimeoutType", Doc("Connection establishement timeout."), ] = None, + # channel args + channel_number: Annotated[ + Union[int, None, object], + Doc("Specify the channel number explicit."), + ] = Parameter.empty, + publisher_confirms: Annotated[ + Union[bool, object], + Doc( + "if `True` the `publish` method will " + "return `bool` type after publish is complete." + "Otherwise it will returns `None`." + ), + ] = Parameter.empty, + on_return_raises: Annotated[ + Union[bool, object], + Doc( + "raise an :class:`aio_pika.exceptions.DeliveryError`" + "when mandatory message will be returned" + ), + ] = Parameter.empty, ) -> "RobustConnection": """Connect broker object to RabbitMQ. @@ -310,6 +356,15 @@ async def connect( # type: ignore[override] """ kwargs: AnyDict = {} + if channel_number is not Parameter.empty: + kwargs["channel_number"] = channel_number + + if publisher_confirms is not Parameter.empty: + kwargs["publisher_confirms"] = publisher_confirms + + if on_return_raises is not Parameter.empty: + kwargs["on_return_raises"] = on_return_raises + if timeout: kwargs["timeout"] = timeout @@ -346,6 +401,9 @@ async def _connect( # type: ignore[override] *, timeout: "TimeoutType", ssl_context: Optional["SSLContext"], + channel_number: Optional[int], + publisher_confirms: bool, + on_return_raises: bool, ) -> "RobustConnection": connection = cast( "RobustConnection", @@ -360,7 +418,11 @@ async def _connect( # type: ignore[override] max_consumers = self._max_consumers channel = self._channel = cast( "RobustChannel", - await connection.channel(), + await connection.channel( + channel_number=channel_number, + publisher_confirms=publisher_confirms, + on_return_raises=on_return_raises, + ), ) declarer = self.declarer = RabbitDeclarer(channel) diff --git a/faststream/rabbit/fastapi/router.py b/faststream/rabbit/fastapi/router.py index 4cc90b25d9..6d13beabae 100644 --- a/faststream/rabbit/fastapi/router.py +++ b/faststream/rabbit/fastapi/router.py @@ -96,6 +96,26 @@ def __init__( "TimeoutType", Doc("Connection establishement timeout."), ] = None, + # channel args + channel_number: Annotated[ + Optional[int], + Doc("Specify the channel number explicit."), + ] = None, + publisher_confirms: Annotated[ + bool, + Doc( + "if `True` the `publish` method will " + "return `bool` type after publish is complete." + "Otherwise it will returns `None`." + ), + ] = True, + on_return_raises: Annotated[ + bool, + Doc( + "raise an :class:`aio_pika.exceptions.DeliveryError`" + "when mandatory message will be returned" + ), + ] = False, # broker args max_consumers: Annotated[ Optional[int], @@ -408,6 +428,9 @@ def __init__( graceful_timeout=graceful_timeout, decoder=decoder, parser=parser, + channel_number=channel_number, + publisher_confirms=publisher_confirms, + on_return_raises=on_return_raises, middlewares=middlewares, security=security, asyncapi_url=asyncapi_url, diff --git a/faststream/rabbit/opentelemetry/__init__.py b/faststream/rabbit/opentelemetry/__init__.py new file mode 100644 index 0000000000..f850b09125 --- /dev/null +++ b/faststream/rabbit/opentelemetry/__init__.py @@ -0,0 +1,3 @@ +from faststream.rabbit.opentelemetry.middleware import RabbitTelemetryMiddleware + +__all__ = ("RabbitTelemetryMiddleware",) diff --git a/faststream/rabbit/opentelemetry/middleware.py b/faststream/rabbit/opentelemetry/middleware.py new file mode 100644 index 0000000000..29a553a7f0 --- /dev/null +++ b/faststream/rabbit/opentelemetry/middleware.py @@ -0,0 +1,24 @@ +from typing import Optional + +from opentelemetry.metrics import Meter, MeterProvider +from opentelemetry.trace import TracerProvider + +from faststream.opentelemetry.middleware import TelemetryMiddleware +from faststream.rabbit.opentelemetry.provider import RabbitTelemetrySettingsProvider + + +class RabbitTelemetryMiddleware(TelemetryMiddleware): + def __init__( + self, + *, + tracer_provider: Optional[TracerProvider] = None, + meter_provider: Optional[MeterProvider] = None, + meter: Optional[Meter] = None, + ) -> None: + super().__init__( + settings_provider_factory=lambda _: RabbitTelemetrySettingsProvider(), + tracer_provider=tracer_provider, + meter_provider=meter_provider, + meter=meter, + include_messages_counters=False, + ) diff --git a/faststream/rabbit/opentelemetry/provider.py b/faststream/rabbit/opentelemetry/provider.py new file mode 100644 index 0000000000..da62338e70 --- /dev/null +++ b/faststream/rabbit/opentelemetry/provider.py @@ -0,0 +1,62 @@ +from typing import TYPE_CHECKING + +from opentelemetry.semconv.trace import SpanAttributes + +from faststream.opentelemetry import TelemetrySettingsProvider +from faststream.opentelemetry.consts import MESSAGING_DESTINATION_PUBLISH_NAME + +if TYPE_CHECKING: + from aio_pika import IncomingMessage + + from faststream.broker.message import StreamMessage + from faststream.types import AnyDict + + +class RabbitTelemetrySettingsProvider(TelemetrySettingsProvider["IncomingMessage"]): + __slots__ = ("messaging_system",) + + def __init__(self) -> None: + self.messaging_system = "rabbitmq" + + def get_consume_attrs_from_message( + self, + msg: "StreamMessage[IncomingMessage]", + ) -> "AnyDict": + return { + SpanAttributes.MESSAGING_SYSTEM: self.messaging_system, + SpanAttributes.MESSAGING_MESSAGE_ID: msg.message_id, + SpanAttributes.MESSAGING_MESSAGE_CONVERSATION_ID: msg.correlation_id, + SpanAttributes.MESSAGING_MESSAGE_PAYLOAD_SIZE_BYTES: len(msg.body), + SpanAttributes.MESSAGING_RABBITMQ_DESTINATION_ROUTING_KEY: msg.raw_message.routing_key, + "messaging.rabbitmq.message.delivery_tag": msg.raw_message.delivery_tag, + MESSAGING_DESTINATION_PUBLISH_NAME: msg.raw_message.exchange, + } + + @staticmethod + def get_consume_destination_name( + msg: "StreamMessage[IncomingMessage]", + ) -> str: + exchange = msg.raw_message.exchange or "default" + routing_key = msg.raw_message.routing_key + return f"{exchange}.{routing_key}" + + def get_publish_attrs_from_kwargs( + self, + kwargs: "AnyDict", + ) -> "AnyDict": + return { + SpanAttributes.MESSAGING_SYSTEM: self.messaging_system, + SpanAttributes.MESSAGING_DESTINATION_NAME: kwargs.get("exchange") or "", + SpanAttributes.MESSAGING_RABBITMQ_DESTINATION_ROUTING_KEY: kwargs[ + "routing_key" + ], + SpanAttributes.MESSAGING_MESSAGE_CONVERSATION_ID: kwargs["correlation_id"], + } + + @staticmethod + def get_publish_destination_name( + kwargs: "AnyDict", + ) -> str: + exchange: str = kwargs.get("exchange") or "default" + routing_key: str = kwargs["routing_key"] + return f"{exchange}.{routing_key}" diff --git a/faststream/rabbit/publisher/usecase.py b/faststream/rabbit/publisher/usecase.py index 505bcf2268..7ac5dc6389 100644 --- a/faststream/rabbit/publisher/usecase.py +++ b/faststream/rabbit/publisher/usecase.py @@ -97,9 +97,10 @@ class LogicPublisher( ): """A class to represent a RabbitMQ publisher.""" - _producer: Optional["AioPikaFastProducer"] app_id: Optional[str] + _producer: Optional["AioPikaFastProducer"] + def __init__( self, *, diff --git a/faststream/rabbit/router.py b/faststream/rabbit/router.py index 98e6438c4c..0890433347 100644 --- a/faststream/rabbit/router.py +++ b/faststream/rabbit/router.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, Union +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterable, Optional, Union from typing_extensions import Annotated, Doc, deprecated @@ -177,7 +177,10 @@ class RabbitRoute(SubscriberRoute): def __init__( self, call: Annotated[ - Callable[..., "AioPikaSendableMessage"], + Union[ + Callable[..., "AioPikaSendableMessage"], + Callable[..., Awaitable["AioPikaSendableMessage"]], + ], Doc( "Message handler function " "to wrap the same with `@broker.subscriber(...)` way." @@ -285,8 +288,8 @@ def __init__( class RabbitRouter( - BrokerRouter["IncomingMessage"], RabbitRegistrator, + BrokerRouter["IncomingMessage"], ): """Includable to RabbitBroker router.""" diff --git a/faststream/rabbit/schemas/queue.py b/faststream/rabbit/schemas/queue.py index b63685d1a5..a9bccf013d 100644 --- a/faststream/rabbit/schemas/queue.py +++ b/faststream/rabbit/schemas/queue.py @@ -1,3 +1,4 @@ +from copy import deepcopy from typing import TYPE_CHECKING, Optional from typing_extensions import Annotated, Doc @@ -115,3 +116,13 @@ def __init__( self.auto_delete = auto_delete self.arguments = arguments self.timeout = timeout + + def add_prefix(self, prefix: str) -> "RabbitQueue": + new_q: RabbitQueue = deepcopy(self) + + new_q.name = "".join((prefix, new_q.name)) + + if new_q.routing_key: + new_q.routing_key = "".join((prefix, new_q.routing_key)) + + return new_q diff --git a/faststream/rabbit/subscriber/usecase.py b/faststream/rabbit/subscriber/usecase.py index aecac22384..c0700dcc82 100644 --- a/faststream/rabbit/subscriber/usecase.py +++ b/faststream/rabbit/subscriber/usecase.py @@ -1,4 +1,3 @@ -from copy import deepcopy from typing import ( TYPE_CHECKING, Any, @@ -105,7 +104,7 @@ def setup( # type: ignore[override] logger: Optional["LoggerProto"], producer: Optional["AioPikaFastProducer"], graceful_timeout: Optional[float], - extra_context: Optional["AnyDict"], + extra_context: "AnyDict", # broker options broker_parser: Optional["CustomCallable"], broker_decoder: Optional["CustomCallable"], @@ -223,6 +222,4 @@ def get_log_context( def add_prefix(self, prefix: str) -> None: """Include Subscriber in router.""" - new_q = deepcopy(self.queue) - new_q.name = "".join((prefix, new_q.name)) - self.queue = new_q + self.queue = self.queue.add_prefix(prefix) diff --git a/faststream/rabbit/testing.py b/faststream/rabbit/testing.py index e425ed02d6..e15cbe2cb3 100644 --- a/faststream/rabbit/testing.py +++ b/faststream/rabbit/testing.py @@ -8,6 +8,7 @@ from typing_extensions import override from faststream.broker.message import gen_cor_id +from faststream.exceptions import WRONG_PUBLISH_ARGS from faststream.rabbit.broker.broker import RabbitBroker from faststream.rabbit.parser import AioPikaParser from faststream.rabbit.publisher.asyncapi import AsyncAPIPublisher @@ -197,6 +198,9 @@ async def publish( # type: ignore[override] """Publish a message to a RabbitMQ queue or exchange.""" exch = RabbitExchange.validate(exchange) + if rpc and reply_to: + raise WRONG_PUBLISH_ARGS + incoming = build_message( message=message, exchange=exch, diff --git a/faststream/redis/broker/broker.py b/faststream/redis/broker/broker.py index 7c0cefe09a..3164c7a01b 100644 --- a/faststream/redis/broker/broker.py +++ b/faststream/redis/broker/broker.py @@ -263,7 +263,7 @@ async def connect( # type: ignore[override] **kwargs, } else: - connect_kwargs = {**kwargs} + connect_kwargs = dict(kwargs).copy() return await super().connect(**connect_kwargs) @@ -359,6 +359,7 @@ async def start(self) -> None: @property def _subscriber_setup_extra(self) -> "AnyDict": return { + **super()._subscriber_setup_extra, "connection": self._connection, } diff --git a/faststream/redis/opentelemetry/__init__.py b/faststream/redis/opentelemetry/__init__.py new file mode 100644 index 0000000000..aea6429256 --- /dev/null +++ b/faststream/redis/opentelemetry/__init__.py @@ -0,0 +1,3 @@ +from faststream.redis.opentelemetry.middleware import RedisTelemetryMiddleware + +__all__ = ("RedisTelemetryMiddleware",) diff --git a/faststream/redis/opentelemetry/middleware.py b/faststream/redis/opentelemetry/middleware.py new file mode 100644 index 0000000000..54c0024143 --- /dev/null +++ b/faststream/redis/opentelemetry/middleware.py @@ -0,0 +1,24 @@ +from typing import Optional + +from opentelemetry.metrics import Meter, MeterProvider +from opentelemetry.trace import TracerProvider + +from faststream.opentelemetry.middleware import TelemetryMiddleware +from faststream.redis.opentelemetry.provider import RedisTelemetrySettingsProvider + + +class RedisTelemetryMiddleware(TelemetryMiddleware): + def __init__( + self, + *, + tracer_provider: Optional[TracerProvider] = None, + meter_provider: Optional[MeterProvider] = None, + meter: Optional[Meter] = None, + ) -> None: + super().__init__( + settings_provider_factory=lambda _: RedisTelemetrySettingsProvider(), + tracer_provider=tracer_provider, + meter_provider=meter_provider, + meter=meter, + include_messages_counters=True, + ) diff --git a/faststream/redis/opentelemetry/provider.py b/faststream/redis/opentelemetry/provider.py new file mode 100644 index 0000000000..1fcfd4e9c3 --- /dev/null +++ b/faststream/redis/opentelemetry/provider.py @@ -0,0 +1,62 @@ +from typing import TYPE_CHECKING, Sized, cast + +from opentelemetry.semconv.trace import SpanAttributes + +from faststream.opentelemetry import TelemetrySettingsProvider +from faststream.opentelemetry.consts import MESSAGING_DESTINATION_PUBLISH_NAME + +if TYPE_CHECKING: + from faststream.broker.message import StreamMessage + from faststream.types import AnyDict + + +class RedisTelemetrySettingsProvider(TelemetrySettingsProvider["AnyDict"]): + __slots__ = ("messaging_system",) + + def __init__(self) -> None: + self.messaging_system = "redis" + + def get_consume_attrs_from_message( + self, + msg: "StreamMessage[AnyDict]", + ) -> "AnyDict": + attrs = { + SpanAttributes.MESSAGING_SYSTEM: self.messaging_system, + SpanAttributes.MESSAGING_MESSAGE_ID: msg.message_id, + SpanAttributes.MESSAGING_MESSAGE_CONVERSATION_ID: msg.correlation_id, + SpanAttributes.MESSAGING_MESSAGE_PAYLOAD_SIZE_BYTES: len(msg.body), + MESSAGING_DESTINATION_PUBLISH_NAME: msg.raw_message["channel"], + } + + if cast(str, msg.raw_message.get("type", "")).startswith("b"): + attrs[SpanAttributes.MESSAGING_BATCH_MESSAGE_COUNT] = len( + cast(Sized, msg.decoded_body) + ) + + return attrs + + def get_consume_destination_name( + self, + msg: "StreamMessage[AnyDict]", + ) -> str: + return self._get_destination(msg.raw_message) + + def get_publish_attrs_from_kwargs( + self, + kwargs: "AnyDict", + ) -> "AnyDict": + return { + SpanAttributes.MESSAGING_SYSTEM: self.messaging_system, + SpanAttributes.MESSAGING_DESTINATION_NAME: self._get_destination(kwargs), + SpanAttributes.MESSAGING_MESSAGE_CONVERSATION_ID: kwargs["correlation_id"], + } + + def get_publish_destination_name( + self, + kwargs: "AnyDict", + ) -> str: + return self._get_destination(kwargs) + + @staticmethod + def _get_destination(kwargs: "AnyDict") -> str: + return kwargs.get("channel") or kwargs.get("list") or kwargs.get("stream") or "" diff --git a/faststream/redis/parser.py b/faststream/redis/parser.py index d47dae603d..52806b7fbd 100644 --- a/faststream/redis/parser.py +++ b/faststream/redis/parser.py @@ -1,6 +1,7 @@ from typing import ( TYPE_CHECKING, Any, + List, Mapping, Optional, Sequence, @@ -135,13 +136,16 @@ async def parse_message( self, message: Mapping[str, Any], ) -> "StreamMessage[Mapping[str, Any]]": - data, headers = self._parse_data(message) + data, headers, batch_headers = self._parse_data(message) + id_ = gen_cor_id() + return self.msg_class( raw_message=message, body=data, path=self.get_path(message), headers=headers, + batch_headers=batch_headers, reply_to=headers.get("reply_to", ""), content_type=headers.get("content-type"), message_id=headers.get("message_id", id_), @@ -149,8 +153,10 @@ async def parse_message( ) @staticmethod - def _parse_data(message: Mapping[str, Any]) -> Tuple[bytes, "AnyDict"]: - return RawMessage.parse(message["data"]) + def _parse_data( + message: Mapping[str, Any], + ) -> Tuple[bytes, "AnyDict", List["AnyDict"]]: + return (*RawMessage.parse(message["data"]), []) def get_path(self, message: Mapping[str, Any]) -> "AnyDict": if ( @@ -182,10 +188,26 @@ class RedisBatchListParser(SimpleParser): msg_class = RedisBatchListMessage @staticmethod - def _parse_data(message: Mapping[str, Any]) -> Tuple[bytes, "AnyDict"]: + def _parse_data( + message: Mapping[str, Any], + ) -> Tuple[bytes, "AnyDict", List["AnyDict"]]: + body: List[Any] = [] + batch_headers: List["AnyDict"] = [] + + for x in message["data"]: + msg_data, msg_headers = _decode_batch_body_item(x) + body.append(msg_data) + batch_headers.append(msg_headers) + + first_msg_headers = next(iter(batch_headers), {}) + return ( - dump_json(_decode_batch_body_item(x) for x in message["data"]), - {"content-type": ContentTypes.json}, + dump_json(body), + { + **first_msg_headers, + "content-type": ContentTypes.json.value, + }, + batch_headers, ) @@ -193,27 +215,43 @@ class RedisStreamParser(SimpleParser): msg_class = RedisStreamMessage @classmethod - def _parse_data(cls, message: Mapping[str, Any]) -> Tuple[bytes, "AnyDict"]: + def _parse_data( + cls, message: Mapping[str, Any] + ) -> Tuple[bytes, "AnyDict", List["AnyDict"]]: data = message["data"] - return RawMessage.parse(data.get(bDATA_KEY) or dump_json(data)) + return (*RawMessage.parse(data.get(bDATA_KEY) or dump_json(data)), []) class RedisBatchStreamParser(SimpleParser): msg_class = RedisBatchStreamMessage @staticmethod - def _parse_data(message: Mapping[str, Any]) -> Tuple[bytes, "AnyDict"]: + def _parse_data( + message: Mapping[str, Any], + ) -> Tuple[bytes, "AnyDict", List["AnyDict"]]: + body: List[Any] = [] + batch_headers: List["AnyDict"] = [] + + for x in message["data"]: + msg_data, msg_headers = _decode_batch_body_item(x.get(bDATA_KEY, x)) + body.append(msg_data) + batch_headers.append(msg_headers) + + first_msg_headers = next(iter(batch_headers), {}) + return ( - dump_json( - _decode_batch_body_item(x.get(bDATA_KEY, x)) for x in message["data"] - ), - {"content-type": ContentTypes.json}, + dump_json(body), + { + **first_msg_headers, + "content-type": ContentTypes.json.value, + }, + batch_headers, ) -def _decode_batch_body_item(msg_content: bytes) -> Any: - msg_body, _ = RawMessage.parse(msg_content) +def _decode_batch_body_item(msg_content: bytes) -> Tuple[Any, "AnyDict"]: + msg_body, headers = RawMessage.parse(msg_content) try: - return json_loads(msg_body) + return json_loads(msg_body), headers except Exception: - return msg_body + return msg_body, headers diff --git a/faststream/redis/publisher/producer.py b/faststream/redis/publisher/producer.py index d5f6f23f9b..ce807aeab8 100644 --- a/faststream/redis/publisher/producer.py +++ b/faststream/redis/publisher/producer.py @@ -126,13 +126,14 @@ async def publish_batch( *msgs: "SendableMessage", list: str, correlation_id: str, + headers: Optional["AnyDict"] = None, ) -> None: batch = ( RawMessage.encode( message=msg, correlation_id=correlation_id, reply_to=None, - headers=None, + headers=headers, ) for msg in msgs ) diff --git a/faststream/redis/router.py b/faststream/redis/router.py index 632413eeeb..635f86083e 100644 --- a/faststream/redis/router.py +++ b/faststream/redis/router.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, Union +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterable, Optional, Union from typing_extensions import Annotated, Doc, deprecated @@ -99,7 +99,10 @@ class RedisRoute(SubscriberRoute): def __init__( self, call: Annotated[ - Callable[..., "SendableMessage"], + Union[ + Callable[..., "SendableMessage"], + Callable[..., Awaitable["SendableMessage"]], + ], Doc( "Message handler function " "to wrap the same with `@broker.subscriber(...)` way." @@ -197,8 +200,8 @@ def __init__( class RedisRouter( - BrokerRouter[BaseMessage], RedisRegistrator, + BrokerRouter[BaseMessage], ): """Includable to RedisBroker router.""" diff --git a/faststream/redis/subscriber/usecase.py b/faststream/redis/subscriber/usecase.py index a35cffce60..7919f384f7 100644 --- a/faststream/redis/subscriber/usecase.py +++ b/faststream/redis/subscriber/usecase.py @@ -103,7 +103,7 @@ def setup( # type: ignore[override] logger: Optional["LoggerProto"], producer: Optional["ProducerProto"], graceful_timeout: Optional[float], - extra_context: Optional["AnyDict"], + extra_context: "AnyDict", # broker options broker_parser: Optional["CustomCallable"], broker_decoder: Optional["CustomCallable"], diff --git a/faststream/redis/testing.py b/faststream/redis/testing.py index 51b14dd0ba..2931bf76e2 100644 --- a/faststream/redis/testing.py +++ b/faststream/redis/testing.py @@ -1,10 +1,11 @@ import re from typing import TYPE_CHECKING, Any, Optional, Sequence, Union +from unittest.mock import AsyncMock, MagicMock from typing_extensions import override from faststream.broker.message import gen_cor_id -from faststream.exceptions import SetupError +from faststream.exceptions import WRONG_PUBLISH_ARGS, SetupError from faststream.redis.broker.broker import RedisBroker from faststream.redis.message import ( BatchListMessage, @@ -49,12 +50,15 @@ def f(msg: Any) -> None: return sub.calls[0].handler @staticmethod - async def _fake_connect( + async def _fake_connect( # type: ignore[override] broker: RedisBroker, *args: Any, **kwargs: Any, - ) -> None: + ) -> AsyncMock: broker._producer = FakeProducer(broker) # type: ignore[assignment] + connection = MagicMock() + connection.pubsub.side_effect = AsyncMock + return connection @staticmethod def remove_publisher_fake_subscriber( @@ -87,6 +91,9 @@ async def publish( # type: ignore[override] rpc_timeout: Optional[float] = 30.0, raise_timeout: bool = False, ) -> Optional[Any]: + if rpc and reply_to: + raise WRONG_PUBLISH_ARGS + correlation_id = correlation_id or gen_cor_id() body = build_message( @@ -176,6 +183,7 @@ async def publish_batch( self, *msgs: "SendableMessage", list: str, + headers: Optional["AnyDict"] = None, correlation_id: Optional[str] = None, ) -> None: correlation_id = correlation_id or gen_cor_id() @@ -193,6 +201,7 @@ async def publish_batch( build_message( m, correlation_id=correlation_id, + headers=headers, ) for m in msgs ], diff --git a/faststream/testing/broker.py b/faststream/testing/broker.py index b278244957..249e5c6846 100644 --- a/faststream/testing/broker.py +++ b/faststream/testing/broker.py @@ -16,6 +16,7 @@ from unittest.mock import AsyncMock, MagicMock from faststream.broker.core.usecase import BrokerUsecase +from faststream.broker.message import StreamMessage, decode_message, encode_message from faststream.broker.middlewares.logging import CriticalLogMiddleware from faststream.broker.wrapper.call import HandlerCallWrapper from faststream.testing.app import TestApp @@ -215,6 +216,11 @@ async def call_handler( result = await handler.consume(message) if rpc: - return result + message_body, content_type = encode_message(result) + msg_to_publish = StreamMessage( + raw_message=None, body=message_body, content_type=content_type + ) + consumed_data = decode_message(msg_to_publish) + return consumed_data return None diff --git a/faststream/types.py b/faststream/types.py index 9f12fb9d57..681a7a3b18 100644 --- a/faststream/types.py +++ b/faststream/types.py @@ -63,22 +63,16 @@ class StandardDataclass(Protocol): """Protocol to check type is dataclass.""" __dataclass_fields__: ClassVar[Dict[str, Any]] - __dataclass_params__: ClassVar[Any] - __post_init__: ClassVar[Callable[..., None]] - - def __init__(self, *args: object, **kwargs: object) -> None: - """Interface method.""" - ... BaseSendableMessage: TypeAlias = Union[ JsonDecodable, Decimal, datetime, - None, StandardDataclass, SendableTable, SendableArray, + None, ] try: diff --git a/faststream/utils/path.py b/faststream/utils/path.py index 96165d81f7..639a54ee06 100644 --- a/faststream/utils/path.py +++ b/faststream/utils/path.py @@ -11,7 +11,7 @@ def compile_path( replace_symbol: str, patch_regex: Callable[[str], str] = lambda x: x, ) -> Tuple[Optional[Pattern[str]], str]: - path_regex = "^" + path_regex = "^.*" original_path = "" idx = 0 diff --git a/pyproject.toml b/pyproject.toml index 418680ffdb..0ef08d8502 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,15 +73,19 @@ nats = ["nats-py>=2.3.1,<=3.0.0"] redis = ["redis>=5.0.0,<6.0.0"] +otel = ["opentelemetry-sdk>=1.24.0,<2.0.0"] + # dev dependencies +optionals = ["faststream[rabbit,kafka,confluent,nats,redis,otel]"] + devdocs = [ - "mkdocs-material==9.5.17", - "mkdocs-static-i18n==1.2.2", + "mkdocs-material==9.5.21", + "mkdocs-static-i18n==1.2.3", "mdx-include==1.4.2", - "mkdocstrings[python]==0.24.3", + "mkdocstrings[python]==0.25.1", "mkdocs-literate-nav==0.6.1", - "mkdocs-git-revision-date-localized-plugin==1.2.4", - "mike==2.0.0", # versioning + "mkdocs-git-revision-date-localized-plugin==1.2.5", + "mike==2.1.1", # versioning "mkdocs-minify-plugin==0.8.0", "mkdocs-macros-plugin==1.0.5", # includes with variables "mkdocs-glightbox==0.3.7", # img zoom @@ -92,8 +96,8 @@ devdocs = [ ] types = [ - "faststream[rabbit,confluent,kafka,nats,redis]", - "mypy==1.9.0", + "faststream[optionals]", + "mypy==1.10.0", # mypy extensions "types-PyYAML", "types-setuptools", @@ -106,22 +110,22 @@ types = [ lint = [ "faststream[types]", - "ruff==0.3.7", + "ruff==0.4.4", "bandit==1.7.8", - "semgrep==1.68.0", + "semgrep==1.70.0", "codespell==2.2.6", ] test-core = [ - "coverage[toml]==7.4.4", - "pytest==8.1.1", + "coverage[toml]==7.5.1", + "pytest==8.2.0", "pytest-asyncio==0.23.6", "dirty-equals==0.7.1.post0", ] testing = [ "faststream[test-core]", - "fastapi==0.110.1", + "fastapi==0.111.0", "pydantic-settings>=2.0.0,<3.0.0", "httpx==0.27.0", "PyYAML==6.0.1", @@ -130,10 +134,10 @@ testing = [ ] dev = [ - "faststream[rabbit,kafka,confluent,nats,redis,lint,testing,devdocs]", + "faststream[optionals,lint,testing,devdocs]", "pre-commit==3.5.0; python_version < '3.9'", "pre-commit==3.7.0; python_version >= '3.9'", - "detect-secrets==1.4.0", + "detect-secrets==1.5.0", ] [project.urls] diff --git a/tests/asyncapi/confluent/__init__.py b/tests/asyncapi/confluent/__init__.py index e69de29bb2..c4a1803708 100644 --- a/tests/asyncapi/confluent/__init__.py +++ b/tests/asyncapi/confluent/__init__.py @@ -0,0 +1,3 @@ +import pytest + +pytest.importorskip("confluent_kafka") diff --git a/tests/asyncapi/kafka/__init__.py b/tests/asyncapi/kafka/__init__.py index e69de29bb2..bd6bc708fc 100644 --- a/tests/asyncapi/kafka/__init__.py +++ b/tests/asyncapi/kafka/__init__.py @@ -0,0 +1,3 @@ +import pytest + +pytest.importorskip("aiokafka") diff --git a/tests/asyncapi/nats/__init__.py b/tests/asyncapi/nats/__init__.py index e69de29bb2..87ead90ee6 100644 --- a/tests/asyncapi/nats/__init__.py +++ b/tests/asyncapi/nats/__init__.py @@ -0,0 +1,3 @@ +import pytest + +pytest.importorskip("nats") diff --git a/tests/asyncapi/rabbit/__init__.py b/tests/asyncapi/rabbit/__init__.py index e69de29bb2..ebec43fcd5 100644 --- a/tests/asyncapi/rabbit/__init__.py +++ b/tests/asyncapi/rabbit/__init__.py @@ -0,0 +1,3 @@ +import pytest + +pytest.importorskip("aio_pika") diff --git a/tests/asyncapi/rabbit/test_router.py b/tests/asyncapi/rabbit/test_router.py index b878eac005..386f4960f5 100644 --- a/tests/asyncapi/rabbit/test_router.py +++ b/tests/asyncapi/rabbit/test_router.py @@ -63,7 +63,7 @@ async def handle(msg): ... "subscribe": { "bindings": { "amqp": { - "cc": "key", + "cc": "test_key", "ack": True, "bindingVersion": "0.2.0", } @@ -91,7 +91,7 @@ async def handle(msg): ... }, }, } - ) + ), schema class TestRouterArguments(ArgumentsTestcase): diff --git a/tests/asyncapi/redis/__init__.py b/tests/asyncapi/redis/__init__.py index e69de29bb2..4752ef19b1 100644 --- a/tests/asyncapi/redis/__init__.py +++ b/tests/asyncapi/redis/__init__.py @@ -0,0 +1,3 @@ +import pytest + +pytest.importorskip("redis") diff --git a/tests/brokers/base/consume.py b/tests/brokers/base/consume.py index 654d3b19f8..7b7b5bdd6b 100644 --- a/tests/brokers/base/consume.py +++ b/tests/brokers/base/consume.py @@ -1,4 +1,5 @@ import asyncio +from abc import abstractmethod from typing import Any, ClassVar, Dict from unittest.mock import MagicMock @@ -15,25 +16,29 @@ class BrokerConsumeTestcase: timeout: int = 3 subscriber_kwargs: ClassVar[Dict[str, Any]] = {} - @pytest.fixture() - def consume_broker(self, broker: BrokerUsecase): + @abstractmethod + def get_broker(self, broker: BrokerUsecase) -> BrokerUsecase[Any, Any]: + raise NotImplementedError + + def patch_broker(self, broker: BrokerUsecase[Any, Any]) -> BrokerUsecase[Any, Any]: return broker async def test_consume( self, queue: str, - consume_broker: BrokerUsecase, event: asyncio.Event, ): + consume_broker = self.get_broker() + @consume_broker.subscriber(queue, **self.subscriber_kwargs) def subscriber(m): event.set() - async with consume_broker: - await consume_broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() await asyncio.wait( ( - asyncio.create_task(consume_broker.publish("hello", queue)), + asyncio.create_task(br.publish("hello", queue)), asyncio.create_task(event.wait()), ), timeout=self.timeout, @@ -44,9 +49,10 @@ def subscriber(m): async def test_consume_from_multi( self, queue: str, - consume_broker: BrokerUsecase, mock: MagicMock, ): + consume_broker = self.get_broker() + consume = asyncio.Event() consume2 = asyncio.Event() @@ -59,12 +65,12 @@ def subscriber(m): else: consume2.set() - async with consume_broker: - await consume_broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() await asyncio.wait( ( - asyncio.create_task(consume_broker.publish("hello", queue)), - asyncio.create_task(consume_broker.publish("hello", queue + "1")), + asyncio.create_task(br.publish("hello", queue)), + asyncio.create_task(br.publish("hello", queue + "1")), asyncio.create_task(consume.wait()), asyncio.create_task(consume2.wait()), ), @@ -78,9 +84,10 @@ def subscriber(m): async def test_consume_double( self, queue: str, - consume_broker: BrokerUsecase, mock: MagicMock, ): + consume_broker = self.get_broker() + consume = asyncio.Event() consume2 = asyncio.Event() @@ -92,12 +99,12 @@ async def handler(m): else: consume2.set() - async with consume_broker: - await consume_broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() await asyncio.wait( ( - asyncio.create_task(consume_broker.publish("hello", queue)), - asyncio.create_task(consume_broker.publish("hello", queue)), + asyncio.create_task(br.publish("hello", queue)), + asyncio.create_task(br.publish("hello", queue)), asyncio.create_task(consume.wait()), asyncio.create_task(consume2.wait()), ), @@ -111,9 +118,10 @@ async def handler(m): async def test_different_consume( self, queue: str, - consume_broker: BrokerUsecase, mock: MagicMock, ): + consume_broker = self.get_broker() + consume = asyncio.Event() consume2 = asyncio.Event() @@ -129,12 +137,12 @@ def handler2(m): mock.handler2() consume2.set() - async with consume_broker: - await consume_broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() await asyncio.wait( ( - asyncio.create_task(consume_broker.publish("hello", queue)), - asyncio.create_task(consume_broker.publish("hello", another_topic)), + asyncio.create_task(br.publish("hello", queue)), + asyncio.create_task(br.publish("hello", another_topic)), asyncio.create_task(consume.wait()), asyncio.create_task(consume2.wait()), ), @@ -149,9 +157,10 @@ def handler2(m): async def test_consume_with_filter( self, queue: str, - consume_broker: BrokerUsecase, mock: MagicMock, ): + consume_broker = self.get_broker() + consume = asyncio.Event() consume2 = asyncio.Event() @@ -169,14 +178,12 @@ async def handler2(m): mock.handler2(m) consume2.set() - async with consume_broker: - await consume_broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() await asyncio.wait( ( - asyncio.create_task( - consume_broker.publish({"msg": "hello"}, queue) - ), - asyncio.create_task(consume_broker.publish("hello", queue)), + asyncio.create_task(br.publish({"msg": "hello"}, queue)), + asyncio.create_task(br.publish("hello", queue)), asyncio.create_task(consume.wait()), asyncio.create_task(consume2.wait()), ), @@ -191,10 +198,11 @@ async def handler2(m): async def test_consume_validate_false( self, queue: str, - consume_broker: BrokerUsecase, event: asyncio.Event, mock: MagicMock, ): + consume_broker = self.get_broker() + consume_broker._is_apply_types = True consume_broker._is_validate = False @@ -209,17 +217,49 @@ async def handler(m: Foo, dep: int = Depends(dependency), broker=Context()): mock(m, dep, broker) event.set() - await consume_broker.start() - await asyncio.wait( - ( - asyncio.create_task(consume_broker.publish({"x": 1}, queue)), - asyncio.create_task(event.wait()), - ), - timeout=self.timeout, - ) + async with self.patch_broker(consume_broker) as br: + await br.start() + + await asyncio.wait( + ( + asyncio.create_task(br.publish({"x": 1}, queue)), + asyncio.create_task(event.wait()), + ), + timeout=self.timeout, + ) + + assert event.is_set() + mock.assert_called_once_with({"x": 1}, "100", consume_broker) + + async def test_dynamic_sub( + self, + queue: str, + event: asyncio.Event, + ): + consume_broker = self.get_broker() + + def subscriber(m): + event.set() + + async with self.patch_broker(consume_broker) as br: + await br.start() + + sub = br.subscriber(queue, **self.subscriber_kwargs) + sub(subscriber) + br.setup_subscriber(sub) + await sub.start() + + await asyncio.wait( + ( + asyncio.create_task(br.publish("hello", queue)), + asyncio.create_task(event.wait()), + ), + timeout=self.timeout, + ) + + await sub.close() assert event.is_set() - mock.assert_called_once_with({"x": 1}, "100", consume_broker) @pytest.mark.asyncio() @@ -228,27 +268,28 @@ class BrokerRealConsumeTestcase(BrokerConsumeTestcase): async def test_stop_consume_exc( self, queue: str, - consume_broker: BrokerUsecase, event: asyncio.Event, mock: MagicMock, ): + consume_broker = self.get_broker() + @consume_broker.subscriber(queue, **self.subscriber_kwargs) def subscriber(m): mock() event.set() raise StopConsume() - async with consume_broker: - await consume_broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() await asyncio.wait( ( - asyncio.create_task(consume_broker.publish("hello", queue)), + asyncio.create_task(br.publish("hello", queue)), asyncio.create_task(event.wait()), ), timeout=self.timeout, ) await asyncio.sleep(0.5) - await consume_broker.publish("hello", queue) + await br.publish("hello", queue) await asyncio.sleep(0.5) assert event.is_set() diff --git a/tests/brokers/base/middlewares.py b/tests/brokers/base/middlewares.py index 4f89f08411..7ed74522d8 100644 --- a/tests/brokers/base/middlewares.py +++ b/tests/brokers/base/middlewares.py @@ -270,6 +270,59 @@ async def handler(m): mock.start.assert_called_once() mock.end.assert_called_once() + async def test_add_global_middleware( + self, + event: asyncio.Event, + queue: str, + mock: Mock, + raw_broker, + ): + class mid(BaseMiddleware): # noqa: N801 + async def on_receive(self): + mock.start(self.msg) + return await super().on_receive() + + async def after_processed(self, exc_type, exc_val, exc_tb): + mock.end() + return await super().after_processed(exc_type, exc_val, exc_tb) + + broker = self.broker_class() + + # already registered subscriber + @broker.subscriber(queue, **self.subscriber_kwargs) + async def handler(m): + event.set() + return "" + + # should affect to already registered and a new subscriber both + broker.add_middleware(mid) + + event2 = asyncio.Event() + + # new subscriber + @broker.subscriber(f"{queue}1", **self.subscriber_kwargs) + async def handler2(m): + event2.set() + return "" + + broker = self.patch_broker(raw_broker, broker) + + async with broker: + await broker.start() + await asyncio.wait( + ( + asyncio.create_task(broker.publish("", queue)), + asyncio.create_task(broker.publish("", f"{queue}1")), + asyncio.create_task(event.wait()), + asyncio.create_task(event2.wait()), + ), + timeout=self.timeout, + ) + + assert event.is_set() + assert mock.start.call_count == 2 + assert mock.end.call_count == 2 + async def test_patch_publish(self, queue: str, mock: Mock, event, raw_broker): class Mid(BaseMiddleware): async def on_publish(self, msg: str, *args, **kwargs) -> str: diff --git a/tests/brokers/base/publish.py b/tests/brokers/base/publish.py index 2ed026c9f7..327f31627b 100644 --- a/tests/brokers/base/publish.py +++ b/tests/brokers/base/publish.py @@ -1,4 +1,6 @@ import asyncio +from abc import abstractmethod +from dataclasses import asdict, dataclass from datetime import datetime from typing import Any, ClassVar, Dict, List, Tuple from unittest.mock import Mock @@ -7,7 +9,7 @@ import pytest from pydantic import BaseModel -from faststream._compat import model_to_json +from faststream._compat import dump_json, model_to_json from faststream.annotations import Logger from faststream.broker.core.usecase import BrokerUsecase @@ -16,6 +18,11 @@ class SimpleModel(BaseModel): r: str +@dataclass +class SimpleDataclass: + r: str + + now = datetime.now() @@ -23,9 +30,12 @@ class BrokerPublishTestcase: timeout: int = 3 subscriber_kwargs: ClassVar[Dict[str, Any]] = {} - @pytest.fixture() - def pub_broker(self, full_broker): - return full_broker + @abstractmethod + def get_broker(self, apply_types: bool = False) -> BrokerUsecase[Any, Any]: + raise NotImplementedError + + def patch_broker(self, broker: BrokerUsecase[Any, Any]) -> BrokerUsecase[Any, Any]: + return broker @pytest.mark.asyncio() @pytest.mark.parametrize( @@ -55,6 +65,12 @@ def pub_broker(self, full_broker): 1.0, id="float->float", ), + pytest.param( + 1, + float, + 1.0, + id="int->float", + ), pytest.param( False, bool, @@ -103,11 +119,34 @@ def pub_broker(self, full_broker): SimpleModel(r="hello!"), id="dict->model", ), + pytest.param( + dump_json(asdict(SimpleDataclass(r="hello!"))), + SimpleDataclass, + SimpleDataclass(r="hello!"), + id="bytes->dataclass", + ), + pytest.param( + SimpleDataclass(r="hello!"), + SimpleDataclass, + SimpleDataclass(r="hello!"), + id="dataclass->dataclass", + ), + pytest.param( + SimpleDataclass(r="hello!"), + dict, + {"r": "hello!"}, + id="dataclass->dict", + ), + pytest.param( + {"r": "hello!"}, + SimpleDataclass, + SimpleDataclass(r="hello!"), + id="dict->dataclass", + ), ), ) async def test_serialize( self, - pub_broker: BrokerUsecase, mock: Mock, queue: str, message, @@ -115,17 +154,19 @@ async def test_serialize( expected_message, event, ): + pub_broker = self.get_broker(apply_types=True) + @pub_broker.subscriber(queue, **self.subscriber_kwargs) async def handler(m: message_type, logger: Logger): event.set() mock(m) - async with pub_broker: - await pub_broker.start() + async with self.patch_broker(pub_broker) as br: + await br.start() await asyncio.wait( ( - asyncio.create_task(pub_broker.publish(message, queue)), + asyncio.create_task(br.publish(message, queue)), asyncio.create_task(event.wait()), ), timeout=self.timeout, @@ -136,18 +177,23 @@ async def handler(m: message_type, logger: Logger): @pytest.mark.asyncio() async def test_unwrap_dict( - self, mock: Mock, queue: str, pub_broker: BrokerUsecase, event + self, + mock: Mock, + queue: str, + event, ): + pub_broker = self.get_broker(apply_types=True) + @pub_broker.subscriber(queue, **self.subscriber_kwargs) async def m(a: int, b: int, logger: Logger): event.set() mock({"a": a, "b": b}) - async with pub_broker: - await pub_broker.start() + async with self.patch_broker(pub_broker) as br: + await br.start() await asyncio.wait( ( - asyncio.create_task(pub_broker.publish({"a": 1, "b": 1.0}, queue)), + asyncio.create_task(br.publish({"a": 1, "b": 1.0}, queue)), asyncio.create_task(event.wait()), ), timeout=self.timeout, @@ -163,18 +209,23 @@ async def m(a: int, b: int, logger: Logger): @pytest.mark.asyncio() async def test_unwrap_list( - self, mock: Mock, queue: str, pub_broker: BrokerUsecase, event: asyncio.Event + self, + mock: Mock, + queue: str, + event: asyncio.Event, ): + pub_broker = self.get_broker(apply_types=True) + @pub_broker.subscriber(queue, **self.subscriber_kwargs) async def m(a: int, b: int, *args: Tuple[int, ...], logger: Logger): event.set() mock({"a": a, "b": b, "args": args}) - async with pub_broker: - await pub_broker.start() + async with self.patch_broker(pub_broker) as br: + await br.start() await asyncio.wait( ( - asyncio.create_task(pub_broker.publish([1, 1.0, 2.0, 3.0], queue)), + asyncio.create_task(br.publish([1, 1.0, 2.0, 3.0], queue)), asyncio.create_task(event.wait()), ), timeout=self.timeout, @@ -187,10 +238,11 @@ async def m(a: int, b: int, *args: Tuple[int, ...], logger: Logger): async def test_base_publisher( self, queue: str, - pub_broker: BrokerUsecase, event, mock, ): + pub_broker = self.get_broker(apply_types=True) + @pub_broker.subscriber(queue, **self.subscriber_kwargs) @pub_broker.publisher(queue + "resp") async def m(): @@ -201,11 +253,11 @@ async def resp(msg): event.set() mock(msg) - async with pub_broker: - await pub_broker.start() + async with self.patch_broker(pub_broker) as br: + await br.start() await asyncio.wait( ( - asyncio.create_task(pub_broker.publish("", queue)), + asyncio.create_task(br.publish("", queue)), asyncio.create_task(event.wait()), ), timeout=self.timeout, @@ -218,10 +270,11 @@ async def resp(msg): async def test_publisher_object( self, queue: str, - pub_broker: BrokerUsecase, event, mock, ): + pub_broker = self.get_broker(apply_types=True) + publisher = pub_broker.publisher(queue + "resp") @publisher @@ -234,11 +287,11 @@ async def resp(msg): event.set() mock(msg) - async with pub_broker: - await pub_broker.start() + async with self.patch_broker(pub_broker) as br: + await br.start() await asyncio.wait( ( - asyncio.create_task(pub_broker.publish("", queue)), + asyncio.create_task(br.publish("", queue)), asyncio.create_task(event.wait()), ), timeout=self.timeout, @@ -251,10 +304,11 @@ async def resp(msg): async def test_publish_manual( self, queue: str, - pub_broker: BrokerUsecase, event, mock, ): + pub_broker = self.get_broker(apply_types=True) + publisher = pub_broker.publisher(queue + "resp") @pub_broker.subscriber(queue, **self.subscriber_kwargs) @@ -266,11 +320,11 @@ async def resp(msg): event.set() mock(msg) - async with pub_broker: - await pub_broker.start() + async with self.patch_broker(pub_broker) as br: + await br.start() await asyncio.wait( ( - asyncio.create_task(pub_broker.publish("", queue)), + asyncio.create_task(br.publish("", queue)), asyncio.create_task(event.wait()), ), timeout=self.timeout, @@ -281,8 +335,12 @@ async def resp(msg): @pytest.mark.asyncio() async def test_multiple_publishers( - self, queue: str, pub_broker: BrokerUsecase, mock + self, + queue: str, + mock, ): + pub_broker = self.get_broker(apply_types=True) + event = anyio.Event() event2 = anyio.Event() @@ -302,11 +360,11 @@ async def resp2(msg): event2.set() mock.resp2(msg) - async with pub_broker: - await pub_broker.start() + async with self.patch_broker(pub_broker) as br: + await br.start() await asyncio.wait( ( - asyncio.create_task(pub_broker.publish("", queue)), + asyncio.create_task(br.publish("", queue)), asyncio.create_task(event.wait()), asyncio.create_task(event2.wait()), ), @@ -320,8 +378,12 @@ async def resp2(msg): @pytest.mark.asyncio() async def test_reusable_publishers( - self, queue: str, pub_broker: BrokerUsecase, mock + self, + queue: str, + mock, ): + pub_broker = self.get_broker(apply_types=True) + consume = anyio.Event() consume2 = anyio.Event() @@ -345,12 +407,12 @@ async def resp(): consume2.set() mock() - async with pub_broker: - await pub_broker.start() + async with self.patch_broker(pub_broker) as br: + await br.start() await asyncio.wait( ( - asyncio.create_task(pub_broker.publish("", queue)), - asyncio.create_task(pub_broker.publish("", queue + "2")), + asyncio.create_task(br.publish("", queue)), + asyncio.create_task(br.publish("", queue + "2")), asyncio.create_task(consume.wait()), asyncio.create_task(consume2.wait()), ), @@ -364,11 +426,12 @@ async def resp(): @pytest.mark.asyncio() async def test_reply_to( self, - pub_broker: BrokerUsecase, queue: str, event, mock, ): + pub_broker = self.get_broker(apply_types=True) + @pub_broker.subscriber(queue + "reply", **self.subscriber_kwargs) async def reply_handler(m): event.set() @@ -378,13 +441,13 @@ async def reply_handler(m): async def handler(m): return m - async with pub_broker: - await pub_broker.start() + async with self.patch_broker(pub_broker) as br: + await br.start() await asyncio.wait( ( asyncio.create_task( - pub_broker.publish("Hello!", queue, reply_to=queue + "reply") + br.publish("Hello!", queue, reply_to=queue + "reply") ), asyncio.create_task(event.wait()), ), @@ -397,20 +460,21 @@ async def handler(m): @pytest.mark.asyncio() async def test_publisher_after_start( self, - pub_broker: BrokerUsecase, queue: str, event, mock, ): + pub_broker = self.get_broker(apply_types=True) + @pub_broker.subscriber(queue, **self.subscriber_kwargs) async def handler(m): event.set() mock(m) - async with pub_broker: - await pub_broker.start() + async with self.patch_broker(pub_broker) as br: + await br.start() - pub = pub_broker.publisher(queue) + pub = br.publisher(queue) await asyncio.wait( ( diff --git a/tests/brokers/base/router.py b/tests/brokers/base/router.py index d22f5e919d..1361f4c9b5 100644 --- a/tests/brokers/base/router.py +++ b/tests/brokers/base/router.py @@ -381,7 +381,7 @@ def subscriber(): ... pub_broker.include_routers(router) sub = next(iter(pub_broker._subscribers.values())) - assert len((*sub._broker_dependecies, *sub.calls[0].dependencies)) == 3 + assert len((*sub._broker_dependencies, *sub.calls[0].dependencies)) == 3 async def test_router_include_with_dependencies( self, @@ -402,7 +402,7 @@ def subscriber(): ... pub_broker.include_router(router, dependencies=(Depends(lambda: 1),)) sub = next(iter(pub_broker._subscribers.values())) - dependencies = (*sub._broker_dependecies, *sub.calls[0].dependencies) + dependencies = (*sub._broker_dependencies, *sub.calls[0].dependencies) assert len(dependencies) == 3, dependencies async def test_router_middlewares( diff --git a/tests/brokers/base/rpc.py b/tests/brokers/base/rpc.py index d4741b3db9..e544360bc5 100644 --- a/tests/brokers/base/rpc.py +++ b/tests/brokers/base/rpc.py @@ -1,4 +1,6 @@ import asyncio +from abc import abstractstaticmethod +from typing import Any from unittest.mock import MagicMock import anyio @@ -9,33 +11,40 @@ class BrokerRPCTestcase: - @pytest.fixture() - def rpc_broker(self, broker): + @abstractstaticmethod + def get_broker(self, apply_types: bool = False) -> BrokerUsecase[Any, Any]: + raise NotImplementedError + + def patch_broker(self, broker: BrokerUsecase[Any, Any]) -> BrokerUsecase[Any, Any]: return broker @pytest.mark.asyncio() - async def test_rpc(self, queue: str, rpc_broker: BrokerUsecase): + async def test_rpc(self, queue: str): + rpc_broker = self.get_broker() + @rpc_broker.subscriber(queue) async def m(m): # pragma: no cover return "1" - async with rpc_broker: - await rpc_broker.start() - r = await rpc_broker.publish("hello", queue, rpc_timeout=3, rpc=True) + async with self.patch_broker(rpc_broker) as br: + await br.start() + r = await br.publish("hello", queue, rpc_timeout=3, rpc=True) assert r == "1" @pytest.mark.asyncio() - async def test_rpc_timeout_raises(self, queue: str, rpc_broker: BrokerUsecase): + async def test_rpc_timeout_raises(self, queue: str): + rpc_broker = self.get_broker() + @rpc_broker.subscriber(queue) async def m(m): # pragma: no cover await anyio.sleep(1) - async with rpc_broker: - await rpc_broker.start() + async with self.patch_broker(rpc_broker) as br: + await br.start() with pytest.raises(TimeoutError): # pragma: no branch - await rpc_broker.publish( + await br.publish( "hello", queue, rpc=True, @@ -44,15 +53,17 @@ async def m(m): # pragma: no cover ) @pytest.mark.asyncio() - async def test_rpc_timeout_none(self, queue: str, rpc_broker: BrokerUsecase): + async def test_rpc_timeout_none(self, queue: str): + rpc_broker = self.get_broker() + @rpc_broker.subscriber(queue) async def m(m): # pragma: no cover await anyio.sleep(1) - async with rpc_broker: - await rpc_broker.start() + async with self.patch_broker(rpc_broker) as br: + await br.start() - r = await rpc_broker.publish( + r = await br.publish( "hello", queue, rpc=True, @@ -65,10 +76,11 @@ async def m(m): # pragma: no cover async def test_rpc_with_reply( self, queue: str, - rpc_broker: BrokerUsecase, mock: MagicMock, event: asyncio.Event, ): + rpc_broker = self.get_broker() + reply_queue = queue + "1" @rpc_broker.subscriber(reply_queue) @@ -80,10 +92,10 @@ async def response_hanler(m: str): async def m(m): # pragma: no cover return "1" - async with rpc_broker: - await rpc_broker.start() + async with self.patch_broker(rpc_broker) as br: + await br.start() - await rpc_broker.publish("hello", queue, reply_to=reply_queue) + await br.publish("hello", queue, reply_to=reply_queue) with timeout_scope(3, True): await event.wait() @@ -93,12 +105,15 @@ async def m(m): # pragma: no cover class ReplyAndConsumeForbidden: @pytest.mark.asyncio() - async def test_rpc_with_reply_and_callback(self, full_broker: BrokerUsecase): - with pytest.raises(ValueError): # noqa: PT011 - await full_broker.publish( - "hello", - "some", - reply_to="some", - rpc=True, - rpc_timeout=0, - ) + async def test_rpc_with_reply_and_callback(self): + rpc_broker = self.get_broker() + + async with rpc_broker: + with pytest.raises(ValueError): # noqa: PT011 + await rpc_broker.publish( + "hello", + "some", + reply_to="some", + rpc=True, + rpc_timeout=0, + ) diff --git a/tests/brokers/base/testclient.py b/tests/brokers/base/testclient.py index 2112519c89..8381c95dc1 100644 --- a/tests/brokers/base/testclient.py +++ b/tests/brokers/base/testclient.py @@ -1,6 +1,6 @@ import pytest -from faststream.broker.core.usecase import BrokerUsecase +from faststream.testing.broker import TestBroker from faststream.types import AnyCallable from tests.brokers.base.consume import BrokerConsumeTestcase from tests.brokers.base.publish import BrokerPublishTestcase @@ -13,61 +13,62 @@ class BrokerTestclientTestcase( BrokerRPCTestcase, ): build_message: AnyCallable - - @pytest.fixture() - def pub_broker(self, test_broker): - return test_broker - - @pytest.fixture() - def consume_broker(self, test_broker): - return test_broker - - @pytest.fixture() - def rpc_broker(self, test_broker): - return test_broker + test_class: TestBroker @pytest.mark.asyncio() - async def test_subscriber_mock(self, queue: str, test_broker: BrokerUsecase): + async def test_subscriber_mock(self, queue: str): + test_broker = self.get_broker() + @test_broker.subscriber(queue) - async def m(): + async def m(msg): pass - await test_broker.start() - await test_broker.publish("hello", queue) - m.mock.assert_called_once_with("hello") + async with self.test_class(test_broker): + await test_broker.start() + await test_broker.publish("hello", queue) + m.mock.assert_called_once_with("hello") @pytest.mark.asyncio() - async def test_publisher_mock(self, queue: str, test_broker: BrokerUsecase): + async def test_publisher_mock(self, queue: str): + test_broker = self.get_broker() + publisher = test_broker.publisher(queue + "resp") @publisher @test_broker.subscriber(queue) - async def m(): + async def m(msg): return "response" - await test_broker.start() - await test_broker.publish("hello", queue) - publisher.mock.assert_called_with("response") + async with self.test_class(test_broker): + await test_broker.start() + await test_broker.publish("hello", queue) + publisher.mock.assert_called_with("response") @pytest.mark.asyncio() - async def test_manual_publisher_mock(self, queue: str, test_broker: BrokerUsecase): + async def test_manual_publisher_mock(self, queue: str): + test_broker = self.get_broker() + publisher = test_broker.publisher(queue + "resp") @test_broker.subscriber(queue) - async def m(): + async def m(msg): await publisher.publish("response") - await test_broker.start() - await test_broker.publish("hello", queue) - publisher.mock.assert_called_with("response") + async with self.test_class(test_broker): + await test_broker.start() + await test_broker.publish("hello", queue) + publisher.mock.assert_called_with("response") @pytest.mark.asyncio() - async def test_exception_raises(self, queue: str, test_broker: BrokerUsecase): + async def test_exception_raises(self, queue: str): + test_broker = self.get_broker() + @test_broker.subscriber(queue) - async def m(): # pragma: no cover + async def m(msg): # pragma: no cover raise ValueError() - await test_broker.start() + async with self.test_class(test_broker): + await test_broker.start() - with pytest.raises(ValueError): # noqa: PT011 - await test_broker.publish("hello", queue) + with pytest.raises(ValueError): # noqa: PT011 + await test_broker.publish("hello", queue) diff --git a/tests/brokers/confluent/__init__.py b/tests/brokers/confluent/__init__.py index e69de29bb2..c4a1803708 100644 --- a/tests/brokers/confluent/__init__.py +++ b/tests/brokers/confluent/__init__.py @@ -0,0 +1,3 @@ +import pytest + +pytest.importorskip("confluent_kafka") diff --git a/tests/brokers/confluent/conftest.py b/tests/brokers/confluent/conftest.py index d128af04c1..aaac741a25 100644 --- a/tests/brokers/confluent/conftest.py +++ b/tests/brokers/confluent/conftest.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from uuid import uuid4 import pytest import pytest_asyncio @@ -31,12 +30,6 @@ async def broker(settings): yield broker -@pytest_asyncio.fixture(scope="session") -async def confluent_kafka_topic(settings): - topic = str(uuid4()) - return topic - - @pytest_asyncio.fixture async def full_broker(settings): broker = KafkaBroker(settings.url) diff --git a/tests/brokers/confluent/test_consume.py b/tests/brokers/confluent/test_consume.py index fb612d66d0..805b3a97f2 100644 --- a/tests/brokers/confluent/test_consume.py +++ b/tests/brokers/confluent/test_consume.py @@ -19,18 +19,23 @@ class TestConsume(BrokerRealConsumeTestcase): timeout: int = 10 subscriber_kwargs: ClassVar[Dict[str, Any]] = {"auto_offset_reset": "earliest"} + def get_broker(self, apply_types: bool = False): + return KafkaBroker(apply_types=apply_types) + @pytest.mark.asyncio() - async def test_consume_batch(self, confluent_kafka_topic: str, broker: KafkaBroker): + async def test_consume_batch(self, queue: str): + consume_broker = self.get_broker() + msgs_queue = asyncio.Queue(maxsize=1) - @broker.subscriber(confluent_kafka_topic, batch=True, **self.subscriber_kwargs) + @consume_broker.subscriber(queue, batch=True, **self.subscriber_kwargs) async def handler(msg): await msgs_queue.put(msg) - async with broker: - await broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() - await broker.publish_batch(1, "hi", topic=confluent_kafka_topic) + await br.publish_batch(1, "hi", topic=queue) result, _ = await asyncio.wait( (asyncio.create_task(msgs_queue.get()),), @@ -39,22 +44,58 @@ async def handler(msg): assert [{1, "hi"}] == [set(r.result()) for r in result] + @pytest.mark.asyncio() + async def test_consume_batch_headers( + self, + mock, + event: asyncio.Event, + queue: str, + ): + consume_broker = self.get_broker(apply_types=True) + + @consume_broker.subscriber(queue, batch=True, **self.subscriber_kwargs) + def subscriber(m, msg: KafkaMessage): + check = all( + ( + msg.headers, + [msg.headers] == msg.batch_headers, + msg.headers.get("custom") == "1", + ) + ) + mock(check) + event.set() + + async with self.patch_broker(consume_broker) as br: + await br.start() + + await asyncio.wait( + ( + asyncio.create_task(br.publish("", queue, headers={"custom": "1"})), + asyncio.create_task(event.wait()), + ), + timeout=self.timeout, + ) + + assert event.is_set() + mock.assert_called_once_with(True) + @pytest.mark.asyncio() @pytest.mark.slow() async def test_consume_ack( self, queue: str, - full_broker: KafkaBroker, event: asyncio.Event, ): - @full_broker.subscriber( + consume_broker = self.get_broker(apply_types=True) + + @consume_broker.subscriber( queue, group_id="test", auto_commit=False, **self.subscriber_kwargs ) async def handler(msg: KafkaMessage): event.set() - async with full_broker: - await full_broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() with patch.object( AsyncConfluentConsumer, @@ -64,7 +105,7 @@ async def handler(msg: KafkaMessage): await asyncio.wait( ( asyncio.create_task( - full_broker.publish( + br.publish( "hello", queue, ) @@ -82,18 +123,19 @@ async def handler(msg: KafkaMessage): async def test_consume_ack_manual( self, queue: str, - full_broker: KafkaBroker, event: asyncio.Event, ): - @full_broker.subscriber( + consume_broker = self.get_broker(apply_types=True) + + @consume_broker.subscriber( queue, group_id="test", auto_commit=False, **self.subscriber_kwargs ) async def handler(msg: KafkaMessage): await msg.ack() event.set() - async with full_broker: - await full_broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() with patch.object( AsyncConfluentConsumer, @@ -102,12 +144,7 @@ async def handler(msg: KafkaMessage): ) as m: await asyncio.wait( ( - asyncio.create_task( - full_broker.publish( - "hello", - queue, - ) - ), + asyncio.create_task(br.publish("hello", queue)), asyncio.create_task(event.wait()), ), timeout=self.timeout, @@ -121,18 +158,19 @@ async def handler(msg: KafkaMessage): async def test_consume_ack_raise( self, queue: str, - full_broker: KafkaBroker, event: asyncio.Event, ): - @full_broker.subscriber( + consume_broker = self.get_broker(apply_types=True) + + @consume_broker.subscriber( queue, group_id="test", auto_commit=False, **self.subscriber_kwargs ) async def handler(msg: KafkaMessage): event.set() raise AckMessage() - async with full_broker: - await full_broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() with patch.object( AsyncConfluentConsumer, @@ -141,12 +179,7 @@ async def handler(msg: KafkaMessage): ) as m: await asyncio.wait( ( - asyncio.create_task( - full_broker.publish( - "hello", - queue, - ) - ), + asyncio.create_task(br.publish("hello", queue)), asyncio.create_task(event.wait()), ), timeout=self.timeout, @@ -160,18 +193,19 @@ async def handler(msg: KafkaMessage): async def test_nack( self, queue: str, - full_broker: KafkaBroker, event: asyncio.Event, ): - @full_broker.subscriber( + consume_broker = self.get_broker(apply_types=True) + + @consume_broker.subscriber( queue, group_id="test", auto_commit=False, **self.subscriber_kwargs ) async def handler(msg: KafkaMessage): await msg.nack() event.set() - async with full_broker: - await full_broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() with patch.object( AsyncConfluentConsumer, @@ -180,12 +214,7 @@ async def handler(msg: KafkaMessage): ) as m: await asyncio.wait( ( - asyncio.create_task( - full_broker.publish( - "hello", - queue, - ) - ), + asyncio.create_task(br.publish("hello", queue)), asyncio.create_task(event.wait()), ), timeout=self.timeout, @@ -199,34 +228,37 @@ async def handler(msg: KafkaMessage): async def test_consume_no_ack( self, queue: str, - full_broker: KafkaBroker, event: asyncio.Event, ): - @full_broker.subscriber( + consume_broker = self.get_broker(apply_types=True) + + @consume_broker.subscriber( queue, group_id="test", no_ack=True, **self.subscriber_kwargs ) async def handler(msg: KafkaMessage): event.set() - await full_broker.start() - with patch.object( - AsyncConfluentConsumer, - "commit", - spy_decorator(AsyncConfluentConsumer.commit), - ) as m: - await asyncio.wait( - ( - asyncio.create_task( - full_broker.publish( - "hello", - queue, - ) + async with self.patch_broker(consume_broker) as br: + await br.start() + + with patch.object( + AsyncConfluentConsumer, + "commit", + spy_decorator(AsyncConfluentConsumer.commit), + ) as m: + await asyncio.wait( + ( + asyncio.create_task( + br.publish( + "hello", + queue, + ) + ), + asyncio.create_task(event.wait()), ), - asyncio.create_task(event.wait()), - ), - timeout=self.timeout, - ) - m.mock.assert_not_called() + timeout=self.timeout, + ) + m.mock.assert_not_called() assert event.is_set() @@ -235,17 +267,18 @@ async def handler(msg: KafkaMessage): async def test_consume_with_no_auto_commit( self, queue: str, - full_broker: KafkaBroker, event: asyncio.Event, ): - @full_broker.subscriber( + consume_broker = self.get_broker(apply_types=True) + + @consume_broker.subscriber( queue, auto_commit=False, group_id="test", **self.subscriber_kwargs ) async def subscriber_no_auto_commit(msg: KafkaMessage): await msg.nack() event.set() - broker2 = KafkaBroker() + broker2 = self.get_broker() event2 = asyncio.Event() @broker2.subscriber( @@ -254,18 +287,20 @@ async def subscriber_no_auto_commit(msg: KafkaMessage): async def subscriber_with_auto_commit(m): event2.set() - async with full_broker: - await full_broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() + await asyncio.wait( ( - asyncio.create_task(full_broker.publish("hello", queue)), + asyncio.create_task(br.publish("hello", queue)), asyncio.create_task(event.wait()), ), timeout=self.timeout, ) - async with broker2: - await broker2.start() + async with self.patch_broker(broker2) as br2: + await br2.start() + await asyncio.wait( (asyncio.create_task(event2.wait()),), timeout=self.timeout, diff --git a/tests/brokers/confluent/test_publish.py b/tests/brokers/confluent/test_publish.py index 156b150356..0fed589efb 100644 --- a/tests/brokers/confluent/test_publish.py +++ b/tests/brokers/confluent/test_publish.py @@ -12,18 +12,23 @@ class TestPublish(BrokerPublishTestcase): timeout: int = 10 subscriber_kwargs: ClassVar[Dict[str, Any]] = {"auto_offset_reset": "earliest"} + def get_broker(self, apply_types: bool = False): + return KafkaBroker(apply_types=apply_types) + @pytest.mark.asyncio() - async def test_publish_batch(self, queue: str, broker: KafkaBroker): + async def test_publish_batch(self, queue: str): + pub_broker = self.get_broker() + msgs_queue = asyncio.Queue(maxsize=2) - @broker.subscriber(queue, **self.subscriber_kwargs) + @pub_broker.subscriber(queue, **self.subscriber_kwargs) async def handler(msg): await msgs_queue.put(msg) - async with broker: - await broker.start() + async with self.patch_broker(pub_broker) as br: + await br.start() - await broker.publish_batch(1, "hi", topic=queue) + await br.publish_batch(1, "hi", topic=queue) result, _ = await asyncio.wait( ( @@ -36,17 +41,19 @@ async def handler(msg): assert {1, "hi"} == {r.result() for r in result} @pytest.mark.asyncio() - async def test_batch_publisher_manual(self, queue: str, broker: KafkaBroker): + async def test_batch_publisher_manual(self, queue: str): + pub_broker = self.get_broker() + msgs_queue = asyncio.Queue(maxsize=2) - @broker.subscriber(queue, **self.subscriber_kwargs) + @pub_broker.subscriber(queue, **self.subscriber_kwargs) async def handler(msg): await msgs_queue.put(msg) - publisher = broker.publisher(queue, batch=True) + publisher = pub_broker.publisher(queue, batch=True) - async with broker: - await broker.start() + async with self.patch_broker(pub_broker) as br: + await br.start() await publisher.publish(1, "hi") @@ -61,22 +68,24 @@ async def handler(msg): assert {1, "hi"} == {r.result() for r in result} @pytest.mark.asyncio() - async def test_batch_publisher_decorator(self, queue: str, broker: KafkaBroker): + async def test_batch_publisher_decorator(self, queue: str): + pub_broker = self.get_broker() + msgs_queue = asyncio.Queue(maxsize=2) - @broker.subscriber(queue, **self.subscriber_kwargs) + @pub_broker.subscriber(queue, **self.subscriber_kwargs) async def handler(msg): await msgs_queue.put(msg) - @broker.publisher(queue, batch=True) - @broker.subscriber(queue + "1", **self.subscriber_kwargs) + @pub_broker.publisher(queue, batch=True) + @pub_broker.subscriber(queue + "1", **self.subscriber_kwargs) async def pub(m): return 1, "hi" - async with broker: - await broker.start() + async with self.patch_broker(pub_broker) as br: + await br.start() - await broker.publish("", queue + "1") + await br.publish("", queue + "1") result, _ = await asyncio.wait( ( diff --git a/tests/brokers/confluent/test_test_client.py b/tests/brokers/confluent/test_test_client.py index d70d2fda6d..b8e232802f 100644 --- a/tests/brokers/confluent/test_test_client.py +++ b/tests/brokers/confluent/test_test_client.py @@ -11,13 +11,22 @@ class TestTestclient(BrokerTestclientTestcase): """A class to represent a test Kafka broker.""" + test_class = TestKafkaBroker + + def get_broker(self, apply_types: bool = False): + return KafkaBroker(apply_types=apply_types) + + def patch_broker(self, broker: KafkaBroker) -> TestKafkaBroker: + return TestKafkaBroker(broker) + @pytest.mark.confluent() async def test_with_real_testclient( self, - broker: KafkaBroker, queue: str, event: asyncio.Event, ): + broker = self.get_broker() + @broker.subscriber(queue, auto_offset_reset="earliest") def subscriber(m): event.set() @@ -35,46 +44,49 @@ def subscriber(m): async def test_batch_pub_by_default_pub( self, - test_broker: KafkaBroker, queue: str, ): - @test_broker.subscriber(queue, batch=True, auto_offset_reset="earliest") - async def m(): + broker = self.get_broker() + + @broker.subscriber(queue, batch=True, auto_offset_reset="earliest") + async def m(msg): pass - await test_broker.start() - await test_broker.publish("hello", queue) - m.mock.assert_called_once_with(["hello"]) + async with self.patch_broker(broker) as br: + await br.publish("hello", queue) + m.mock.assert_called_once_with(["hello"]) async def test_batch_pub_by_pub_batch( self, - test_broker: KafkaBroker, queue: str, ): - @test_broker.subscriber(queue, batch=True, auto_offset_reset="earliest") - async def m(): + broker = self.get_broker() + + @broker.subscriber(queue, batch=True, auto_offset_reset="earliest") + async def m(msg): pass - await test_broker.start() - await test_broker.publish_batch("hello", topic=queue) - m.mock.assert_called_once_with(["hello"]) + async with self.patch_broker(broker) as br: + await br.publish_batch("hello", topic=queue) + m.mock.assert_called_once_with(["hello"]) async def test_batch_publisher_mock( self, - test_broker: KafkaBroker, queue: str, ): - publisher = test_broker.publisher(queue + "1", batch=True) + broker = self.get_broker() + + publisher = broker.publisher(queue + "1", batch=True) @publisher - @test_broker.subscriber(queue, auto_offset_reset="earliest") - async def m(): + @broker.subscriber(queue, auto_offset_reset="earliest") + async def m(msg): return 1, 2, 3 - await test_broker.start() - await test_broker.publish("hello", queue) - m.mock.assert_called_once_with("hello") - publisher.mock.assert_called_once_with([1, 2, 3]) + async with self.patch_broker(broker) as br: + await br.publish("hello", queue) + m.mock.assert_called_once_with("hello") + publisher.mock.assert_called_once_with([1, 2, 3]) async def test_respect_middleware(self, queue): routes = [] diff --git a/tests/brokers/conftest.py b/tests/brokers/conftest.py deleted file mode 100644 index 5aac495a23..0000000000 --- a/tests/brokers/conftest.py +++ /dev/null @@ -1,8 +0,0 @@ -from uuid import uuid4 - -import pytest - - -@pytest.fixture() -def queue(): - return str(uuid4()) diff --git a/tests/brokers/kafka/__init__.py b/tests/brokers/kafka/__init__.py index e69de29bb2..bd6bc708fc 100644 --- a/tests/brokers/kafka/__init__.py +++ b/tests/brokers/kafka/__init__.py @@ -0,0 +1,3 @@ +import pytest + +pytest.importorskip("aiokafka") diff --git a/tests/brokers/kafka/test_consume.py b/tests/brokers/kafka/test_consume.py index fdef8a20bc..2a7f57b888 100644 --- a/tests/brokers/kafka/test_consume.py +++ b/tests/brokers/kafka/test_consume.py @@ -5,7 +5,7 @@ from aiokafka import AIOKafkaConsumer from faststream.exceptions import AckMessage -from faststream.kafka import KafkaBroker +from faststream.kafka import KafkaBroker, TopicPartition from faststream.kafka.annotations import KafkaMessage from tests.brokers.base.consume import BrokerRealConsumeTestcase from tests.tools import spy_decorator @@ -13,18 +13,23 @@ @pytest.mark.kafka() class TestConsume(BrokerRealConsumeTestcase): + def get_broker(self, apply_types: bool = False): + return KafkaBroker(apply_types=apply_types) + @pytest.mark.asyncio() - async def test_consume_batch(self, queue: str, broker: KafkaBroker): + async def test_consume_batch(self, queue: str): + consume_broker = self.get_broker() + msgs_queue = asyncio.Queue(maxsize=1) - @broker.subscriber(queue, batch=True) + @consume_broker.subscriber(queue, batch=True) async def handler(msg): await msgs_queue.put(msg) - async with broker: - await broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() - await broker.publish_batch(1, "hi", topic=queue) + await br.publish_batch(1, "hi", topic=queue) result, _ = await asyncio.wait( (asyncio.create_task(msgs_queue.get()),), @@ -33,20 +38,56 @@ async def handler(msg): assert [{1, "hi"}] == [set(r.result()) for r in result] + @pytest.mark.asyncio() + async def test_consume_batch_headers( + self, + mock, + event: asyncio.Event, + queue: str, + ): + consume_broker = self.get_broker(apply_types=True) + + @consume_broker.subscriber(queue, batch=True) + def subscriber(m, msg: KafkaMessage): + check = all( + ( + msg.headers, + [msg.headers] == msg.batch_headers, + msg.headers.get("custom") == "1", + ) + ) + mock(check) + event.set() + + async with self.patch_broker(consume_broker) as br: + await br.start() + + await asyncio.wait( + ( + asyncio.create_task(br.publish("", queue, headers={"custom": "1"})), + asyncio.create_task(event.wait()), + ), + timeout=3, + ) + + assert event.is_set() + mock.assert_called_once_with(True) + @pytest.mark.asyncio() @pytest.mark.slow() async def test_consume_ack( self, queue: str, - full_broker: KafkaBroker, event: asyncio.Event, ): - @full_broker.subscriber(queue, group_id="test", auto_commit=False) + consume_broker = self.get_broker(apply_types=True) + + @consume_broker.subscriber(queue, group_id="test", auto_commit=False) async def handler(msg: KafkaMessage): event.set() - async with full_broker: - await full_broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() with patch.object( AIOKafkaConsumer, "commit", spy_decorator(AIOKafkaConsumer.commit) @@ -54,7 +95,7 @@ async def handler(msg: KafkaMessage): await asyncio.wait( ( asyncio.create_task( - full_broker.publish( + consume_broker.publish( "hello", queue, ) @@ -67,21 +108,49 @@ async def handler(msg: KafkaMessage): assert event.is_set() + @pytest.mark.asyncio() + async def test_manual_partition_consume( + self, + queue: str, + event: asyncio.Event, + ): + consume_broker = self.get_broker() + + tp1 = TopicPartition(queue, partition=0) + + @consume_broker.subscriber(partitions=[tp1]) + async def handler_tp1(msg): + event.set() + + async with self.patch_broker(consume_broker) as br: + await br.start() + + await asyncio.wait( + ( + asyncio.create_task(br.publish("hello", queue, partition=0)), + asyncio.create_task(event.wait()), + ), + timeout=10, + ) + + assert event.is_set() + @pytest.mark.asyncio() @pytest.mark.slow() async def test_consume_ack_manual( self, queue: str, - full_broker: KafkaBroker, event: asyncio.Event, ): - @full_broker.subscriber(queue, group_id="test", auto_commit=False) + consume_broker = self.get_broker(apply_types=True) + + @consume_broker.subscriber(queue, group_id="test", auto_commit=False) async def handler(msg: KafkaMessage): await msg.ack() event.set() - async with full_broker: - await full_broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() with patch.object( AIOKafkaConsumer, "commit", spy_decorator(AIOKafkaConsumer.commit) @@ -89,7 +158,7 @@ async def handler(msg: KafkaMessage): await asyncio.wait( ( asyncio.create_task( - full_broker.publish( + br.publish( "hello", queue, ) @@ -107,16 +176,17 @@ async def handler(msg: KafkaMessage): async def test_consume_ack_raise( self, queue: str, - full_broker: KafkaBroker, event: asyncio.Event, ): - @full_broker.subscriber(queue, group_id="test", auto_commit=False) + consume_broker = self.get_broker(apply_types=True) + + @consume_broker.subscriber(queue, group_id="test", auto_commit=False) async def handler(msg: KafkaMessage): event.set() raise AckMessage() - async with full_broker: - await full_broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() with patch.object( AIOKafkaConsumer, "commit", spy_decorator(AIOKafkaConsumer.commit) @@ -124,7 +194,7 @@ async def handler(msg: KafkaMessage): await asyncio.wait( ( asyncio.create_task( - full_broker.publish( + br.publish( "hello", queue, ) @@ -142,16 +212,17 @@ async def handler(msg: KafkaMessage): async def test_nack( self, queue: str, - full_broker: KafkaBroker, event: asyncio.Event, ): - @full_broker.subscriber(queue, group_id="test", auto_commit=False) + consume_broker = self.get_broker(apply_types=True) + + @consume_broker.subscriber(queue, group_id="test", auto_commit=False) async def handler(msg: KafkaMessage): await msg.nack() event.set() - async with full_broker: - await full_broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() with patch.object( AIOKafkaConsumer, "commit", spy_decorator(AIOKafkaConsumer.commit) @@ -159,7 +230,7 @@ async def handler(msg: KafkaMessage): await asyncio.wait( ( asyncio.create_task( - full_broker.publish( + br.publish( "hello", queue, ) @@ -177,29 +248,32 @@ async def handler(msg: KafkaMessage): async def test_consume_no_ack( self, queue: str, - full_broker: KafkaBroker, event: asyncio.Event, ): - @full_broker.subscriber(queue, group_id="test", no_ack=True) + consume_broker = self.get_broker(apply_types=True) + + @consume_broker.subscriber(queue, group_id="test", no_ack=True) async def handler(msg: KafkaMessage): event.set() - await full_broker.start() - with patch.object( - AIOKafkaConsumer, "commit", spy_decorator(AIOKafkaConsumer.commit) - ) as m: - await asyncio.wait( - ( - asyncio.create_task( - full_broker.publish( - "hello", - queue, - ) + async with self.patch_broker(consume_broker) as br: + await br.start() + + with patch.object( + AIOKafkaConsumer, "commit", spy_decorator(AIOKafkaConsumer.commit) + ) as m: + await asyncio.wait( + ( + asyncio.create_task( + br.publish( + "hello", + queue, + ) + ), + asyncio.create_task(event.wait()), ), - asyncio.create_task(event.wait()), - ), - timeout=10, - ) - m.mock.assert_not_called() + timeout=10, + ) + m.mock.assert_not_called() - assert event.is_set() + assert event.is_set() diff --git a/tests/brokers/kafka/test_publish.py b/tests/brokers/kafka/test_publish.py index 2aee2ad0ca..e913e3c638 100644 --- a/tests/brokers/kafka/test_publish.py +++ b/tests/brokers/kafka/test_publish.py @@ -8,18 +8,23 @@ @pytest.mark.kafka() class TestPublish(BrokerPublishTestcase): + def get_broker(self, apply_types: bool = False): + return KafkaBroker(apply_types=apply_types) + @pytest.mark.asyncio() - async def test_publish_batch(self, queue: str, broker: KafkaBroker): + async def test_publish_batch(self, queue: str): + pub_broker = self.get_broker() + msgs_queue = asyncio.Queue(maxsize=2) - @broker.subscriber(queue) + @pub_broker.subscriber(queue) async def handler(msg): await msgs_queue.put(msg) - async with broker: - await broker.start() + async with self.patch_broker(pub_broker) as br: + await br.start() - await broker.publish_batch(1, "hi", topic=queue) + await br.publish_batch(1, "hi", topic=queue) result, _ = await asyncio.wait( ( @@ -32,17 +37,19 @@ async def handler(msg): assert {1, "hi"} == {r.result() for r in result} @pytest.mark.asyncio() - async def test_batch_publisher_manual(self, queue: str, broker: KafkaBroker): + async def test_batch_publisher_manual(self, queue: str): + pub_broker = self.get_broker() + msgs_queue = asyncio.Queue(maxsize=2) - @broker.subscriber(queue) + @pub_broker.subscriber(queue) async def handler(msg): await msgs_queue.put(msg) - publisher = broker.publisher(queue, batch=True) + publisher = pub_broker.publisher(queue, batch=True) - async with broker: - await broker.start() + async with self.patch_broker(pub_broker) as br: + await br.start() await publisher.publish(1, "hi") @@ -57,22 +64,24 @@ async def handler(msg): assert {1, "hi"} == {r.result() for r in result} @pytest.mark.asyncio() - async def test_batch_publisher_decorator(self, queue: str, broker: KafkaBroker): + async def test_batch_publisher_decorator(self, queue: str): + pub_broker = self.get_broker() + msgs_queue = asyncio.Queue(maxsize=2) - @broker.subscriber(queue) + @pub_broker.subscriber(queue) async def handler(msg): await msgs_queue.put(msg) - @broker.publisher(queue, batch=True) - @broker.subscriber(queue + "1") + @pub_broker.publisher(queue, batch=True) + @pub_broker.subscriber(queue + "1") async def pub(m): return 1, "hi" - async with broker: - await broker.start() + async with self.patch_broker(pub_broker) as br: + await br.start() - await broker.publish("", queue + "1") + await br.publish("", queue + "1") result, _ = await asyncio.wait( ( diff --git a/tests/brokers/kafka/test_test_client.py b/tests/brokers/kafka/test_test_client.py index 7c72e6c525..a89ecff707 100644 --- a/tests/brokers/kafka/test_test_client.py +++ b/tests/brokers/kafka/test_test_client.py @@ -3,19 +3,78 @@ import pytest from faststream import BaseMiddleware -from faststream.kafka import KafkaBroker, TestKafkaBroker +from faststream.kafka import KafkaBroker, TestKafkaBroker, TopicPartition from tests.brokers.base.testclient import BrokerTestclientTestcase @pytest.mark.asyncio() class TestTestclient(BrokerTestclientTestcase): + test_class = TestKafkaBroker + + def get_broker(self, apply_types: bool = False): + return KafkaBroker(apply_types=apply_types) + + def patch_broker(self, broker: KafkaBroker) -> TestKafkaBroker: + return TestKafkaBroker(broker) + + async def test_partition_match( + self, + queue: str, + ): + broker = self.get_broker() + + @broker.subscriber(partitions=[TopicPartition(queue, 1)]) + async def m(msg): + pass + + async with self.patch_broker(broker) as br: + await br.publish("hello", queue) + + m.mock.assert_called_once_with("hello") + + async def test_partition_match_exect( + self, + queue: str, + ): + broker = self.get_broker() + + @broker.subscriber(partitions=[TopicPartition(queue, 1)]) + async def m(msg): + pass + + async with self.patch_broker(broker) as br: + await br.publish("hello", queue, partition=1) + + m.mock.assert_called_once_with("hello") + + async def test_partition_missmatch( + self, + queue: str, + ): + broker = self.get_broker() + + @broker.subscriber(partitions=[TopicPartition(queue, 1)]) + async def m(msg): + pass + + @broker.subscriber(queue) + async def m2(msg): + pass + + async with self.patch_broker(broker) as br: + await br.publish("hello", queue, partition=2) + + assert not m.mock.called + m2.mock.assert_called_once_with("hello") + @pytest.mark.kafka() async def test_with_real_testclient( self, - broker: KafkaBroker, queue: str, event: asyncio.Event, ): + broker = self.get_broker() + @broker.subscriber(queue) def subscriber(m): event.set() @@ -33,46 +92,49 @@ def subscriber(m): async def test_batch_pub_by_default_pub( self, - test_broker: KafkaBroker, queue: str, ): - @test_broker.subscriber(queue, batch=True) - async def m(): + broker = self.get_broker() + + @broker.subscriber(queue, batch=True) + async def m(msg): pass - await test_broker.start() - await test_broker.publish("hello", queue) - m.mock.assert_called_once_with(["hello"]) + async with TestKafkaBroker(broker) as br: + await br.publish("hello", queue) + m.mock.assert_called_once_with(["hello"]) async def test_batch_pub_by_pub_batch( self, - test_broker: KafkaBroker, queue: str, ): - @test_broker.subscriber(queue, batch=True) - async def m(): + broker = self.get_broker() + + @broker.subscriber(queue, batch=True) + async def m(msg): pass - await test_broker.start() - await test_broker.publish_batch("hello", topic=queue) - m.mock.assert_called_once_with(["hello"]) + async with TestKafkaBroker(broker) as br: + await br.publish_batch("hello", topic=queue) + m.mock.assert_called_once_with(["hello"]) async def test_batch_publisher_mock( self, - test_broker: KafkaBroker, queue: str, ): - publisher = test_broker.publisher(queue + "1", batch=True) + broker = self.get_broker() + + publisher = broker.publisher(queue + "1", batch=True) @publisher - @test_broker.subscriber(queue) - async def m(): + @broker.subscriber(queue) + async def m(msg): return 1, 2, 3 - await test_broker.start() - await test_broker.publish("hello", queue) - m.mock.assert_called_once_with("hello") - publisher.mock.assert_called_once_with([1, 2, 3]) + async with TestKafkaBroker(broker) as br: + await br.publish("hello", queue) + m.mock.assert_called_once_with("hello") + publisher.mock.assert_called_once_with([1, 2, 3]) async def test_respect_middleware(self, queue): routes = [] diff --git a/tests/brokers/nats/__init__.py b/tests/brokers/nats/__init__.py new file mode 100644 index 0000000000..87ead90ee6 --- /dev/null +++ b/tests/brokers/nats/__init__.py @@ -0,0 +1,3 @@ +import pytest + +pytest.importorskip("nats") diff --git a/tests/brokers/nats/test_consume.py b/tests/brokers/nats/test_consume.py index cd496b1266..733797c150 100644 --- a/tests/brokers/nats/test_consume.py +++ b/tests/brokers/nats/test_consume.py @@ -13,24 +13,26 @@ @pytest.mark.nats() class TestConsume(BrokerRealConsumeTestcase): + def get_broker(self, apply_types: bool = False) -> NatsBroker: + return NatsBroker(apply_types=apply_types) + async def test_consume_js( self, queue: str, - consume_broker: NatsBroker, stream: JStream, event: asyncio.Event, ): + consume_broker = self.get_broker() + @consume_broker.subscriber(queue, stream=stream) def subscriber(m): event.set() - async with consume_broker: - await consume_broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() await asyncio.wait( ( - asyncio.create_task( - consume_broker.publish("hello", queue, stream=stream.name) - ), + asyncio.create_task(br.publish("hello", queue, stream=stream.name)), asyncio.create_task(event.wait()), ), timeout=3, @@ -41,11 +43,12 @@ def subscriber(m): async def test_consume_pull( self, queue: str, - consume_broker: NatsBroker, stream: JStream, event: asyncio.Event, mock, ): + consume_broker = self.get_broker() + @consume_broker.subscriber( queue, stream=stream, @@ -55,26 +58,29 @@ def subscriber(m): mock(m) event.set() - await consume_broker.start() - await asyncio.wait( - ( - asyncio.create_task(consume_broker.publish("hello", queue)), - asyncio.create_task(event.wait()), - ), - timeout=3, - ) + async with self.patch_broker(consume_broker) as br: + await br.start() - assert event.is_set() - mock.assert_called_once_with("hello") + await asyncio.wait( + ( + asyncio.create_task(br.publish("hello", queue)), + asyncio.create_task(event.wait()), + ), + timeout=3, + ) + + assert event.is_set() + mock.assert_called_once_with("hello") async def test_consume_batch( self, queue: str, - consume_broker: NatsBroker, stream: JStream, event: asyncio.Event, mock, ): + consume_broker = self.get_broker() + @consume_broker.subscriber( queue, stream=stream, @@ -84,41 +90,39 @@ def subscriber(m): mock(m) event.set() - await consume_broker.start() - await asyncio.wait( - ( - asyncio.create_task(consume_broker.publish(b"hello", queue)), - asyncio.create_task(event.wait()), - ), - timeout=3, - ) + async with self.patch_broker(consume_broker) as br: + await br.start() - assert event.is_set() - mock.assert_called_once_with([b"hello"]) + await asyncio.wait( + ( + asyncio.create_task(br.publish(b"hello", queue)), + asyncio.create_task(event.wait()), + ), + timeout=3, + ) + + assert event.is_set() + mock.assert_called_once_with([b"hello"]) - @pytest.mark.asyncio() async def test_consume_ack( self, queue: str, - full_broker: NatsBroker, event: asyncio.Event, stream: JStream, ): - @full_broker.subscriber(queue, stream=stream) + consume_broker = self.get_broker(apply_types=True) + + @consume_broker.subscriber(queue, stream=stream) async def handler(msg: NatsMessage): event.set() - async with full_broker: - await full_broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() + with patch.object(Msg, "ack", spy_decorator(Msg.ack)) as m: await asyncio.wait( ( - asyncio.create_task( - full_broker.publish( - "hello", - queue, - ) - ), + asyncio.create_task(br.publish("hello", queue)), asyncio.create_task(event.wait()), ), timeout=3, @@ -127,30 +131,26 @@ async def handler(msg: NatsMessage): assert event.is_set() - @pytest.mark.asyncio() async def test_consume_ack_manual( self, queue: str, - full_broker: NatsBroker, event: asyncio.Event, stream: JStream, ): - @full_broker.subscriber(queue, stream=stream) + consume_broker = self.get_broker(apply_types=True) + + @consume_broker.subscriber(queue, stream=stream) async def handler(msg: NatsMessage): await msg.ack() event.set() - async with full_broker: - await full_broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() + with patch.object(Msg, "ack", spy_decorator(Msg.ack)) as m: await asyncio.wait( ( - asyncio.create_task( - full_broker.publish( - "hello", - queue, - ) - ), + asyncio.create_task(br.publish("hello", queue)), asyncio.create_task(event.wait()), ), timeout=3, @@ -159,30 +159,26 @@ async def handler(msg: NatsMessage): assert event.is_set() - @pytest.mark.asyncio() async def test_consume_ack_raise( self, queue: str, - full_broker: NatsBroker, event: asyncio.Event, stream: JStream, ): - @full_broker.subscriber(queue, stream=stream) + consume_broker = self.get_broker(apply_types=True) + + @consume_broker.subscriber(queue, stream=stream) async def handler(msg: NatsMessage): event.set() raise AckMessage() - async with full_broker: - await full_broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() + with patch.object(Msg, "ack", spy_decorator(Msg.ack)) as m: await asyncio.wait( ( - asyncio.create_task( - full_broker.publish( - "hello", - queue, - ) - ), + asyncio.create_task(br.publish("hello", queue)), asyncio.create_task(event.wait()), ), timeout=3, @@ -191,30 +187,26 @@ async def handler(msg: NatsMessage): assert event.is_set() - @pytest.mark.asyncio() async def test_nack( self, queue: str, - full_broker: NatsBroker, event: asyncio.Event, stream: JStream, ): - @full_broker.subscriber(queue, stream=stream) + consume_broker = self.get_broker(apply_types=True) + + @consume_broker.subscriber(queue, stream=stream) async def handler(msg: NatsMessage): await msg.nack() event.set() - async with full_broker: - await full_broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() + with patch.object(Msg, "nak", spy_decorator(Msg.nak)) as m: await asyncio.wait( ( - asyncio.create_task( - full_broker.publish( - "hello", - queue, - ) - ), + asyncio.create_task(br.publish("hello", queue)), asyncio.create_task(event.wait()), ), timeout=3, @@ -223,29 +215,66 @@ async def handler(msg: NatsMessage): assert event.is_set() - @pytest.mark.asyncio() async def test_consume_no_ack( - self, queue: str, full_broker: NatsBroker, event: asyncio.Event + self, + queue: str, + event: asyncio.Event, ): - @full_broker.subscriber(queue, no_ack=True) + consume_broker = self.get_broker(apply_types=True) + + @consume_broker.subscriber(queue, no_ack=True) async def handler(msg: NatsMessage): event.set() - await full_broker.start() - with patch.object(Msg, "ack", spy_decorator(Msg.ack)) as m: + async with self.patch_broker(consume_broker) as br: + await br.start() + + with patch.object(Msg, "ack", spy_decorator(Msg.ack)) as m: + await asyncio.wait( + ( + asyncio.create_task(br.publish("hello", queue)), + asyncio.create_task(event.wait()), + ), + timeout=3, + ) + m.mock.assert_not_called() + + assert event.is_set() + + async def test_consume_batch_headers( + self, + queue: str, + stream: JStream, + event: asyncio.Event, + mock, + ): + consume_broker = self.get_broker(apply_types=True) + + @consume_broker.subscriber( + queue, + stream=stream, + pull_sub=PullSub(1, batch=True), + ) + def subscriber(m, msg: NatsMessage): + check = all( + ( + msg.headers, + [msg.headers] == msg.batch_headers, + msg.headers.get("custom") == "1", + ) + ) + mock(check) + event.set() + + async with self.patch_broker(consume_broker) as br: + await br.start() await asyncio.wait( ( - asyncio.create_task( - full_broker.publish( - "hello", - queue, - ) - ), + asyncio.create_task(br.publish("", queue, headers={"custom": "1"})), asyncio.create_task(event.wait()), ), timeout=3, ) - m.mock.assert_not_called() assert event.is_set() @@ -280,6 +309,7 @@ async def handler(m): assert event.is_set() mock.assert_called_with(b"world") + @pytest.mark.asyncio() async def test_consume_os( self, queue: str, full_broker: NatsBroker, event: asyncio.Event @@ -307,3 +337,5 @@ async def handler(filename: str): ) assert event.is_set() + assert event.is_set() + mock.assert_called_once_with(True) diff --git a/tests/brokers/nats/test_fastapi.py b/tests/brokers/nats/test_fastapi.py index 66f9206ed1..237fe5f81e 100644 --- a/tests/brokers/nats/test_fastapi.py +++ b/tests/brokers/nats/test_fastapi.py @@ -14,6 +14,32 @@ class TestRouter(FastAPITestcase): router_class = NatsRouter + async def test_path( + self, + queue: str, + event: asyncio.Event, + mock: MagicMock, + ): + router = NatsRouter() + + @router.subscriber("in.{name}") + def subscriber(msg: str, name: str): + mock(msg=msg, name=name) + event.set() + + async with router.broker: + await router.broker.start() + await asyncio.wait( + ( + asyncio.create_task(router.broker.publish("hello", "in.john")), + asyncio.create_task(event.wait()), + ), + timeout=3, + ) + + assert event.is_set() + mock.assert_called_once_with(msg="hello", name="john") + async def test_consume_batch( self, queue: str, diff --git a/tests/brokers/nats/test_publish.py b/tests/brokers/nats/test_publish.py index cfe7c74000..0f4aa8581d 100644 --- a/tests/brokers/nats/test_publish.py +++ b/tests/brokers/nats/test_publish.py @@ -1,5 +1,6 @@ import pytest +from faststream.nats import NatsBroker from tests.brokers.base.publish import BrokerPublishTestcase @@ -7,28 +8,5 @@ class TestPublish(BrokerPublishTestcase): """Test publish method of NATS broker.""" - @pytest.mark.asyncio() - async def test_stream_publish( - self, - queue: str, - test_broker, - ): - @test_broker.subscriber(queue, stream="test") - async def m(): ... - - await test_broker.start() - await test_broker.publish("Hi!", queue, stream="test") - m.mock.assert_called_once_with("Hi!") - - @pytest.mark.asyncio() - async def test_wrong_stream_publish( - self, - queue: str, - test_broker, - ): - @test_broker.subscriber(queue) - async def m(): ... - - await test_broker.start() - await test_broker.publish("Hi!", queue, stream="test") - assert not m.mock.called + def get_broker(self, apply_types: bool = False) -> NatsBroker: + return NatsBroker(apply_types=apply_types) diff --git a/tests/brokers/nats/test_router.py b/tests/brokers/nats/test_router.py index c2b0f44228..a0951a06d1 100644 --- a/tests/brokers/nats/test_router.py +++ b/tests/brokers/nats/test_router.py @@ -2,6 +2,7 @@ import pytest +from faststream import Path from faststream.nats import NatsPublisher, NatsRoute, NatsRouter from tests.brokers.base.router import RouterLocalTestcase, RouterTestcase @@ -12,6 +13,96 @@ class TestRouter(RouterTestcase): route_class = NatsRoute publisher_class = NatsPublisher + async def test_router_path( + self, + event, + mock, + router: NatsRouter, + pub_broker, + ): + @router.subscriber("in.{name}.{id}") + async def h( + name: str = Path(), + id: int = Path("id"), + ): + event.set() + mock(name=name, id=id) + + pub_broker._is_apply_types = True + pub_broker.include_router(router) + + await pub_broker.start() + + await pub_broker.publish( + "", + "in.john.2", + rpc=True, + ) + + assert event.is_set() + mock.assert_called_once_with(name="john", id=2) + + async def test_router_path_with_prefix( + self, + event, + mock, + router: NatsRouter, + pub_broker, + ): + router.prefix = "test." + + @router.subscriber("in.{name}.{id}") + async def h( + name: str = Path(), + id: int = Path("id"), + ): + event.set() + mock(name=name, id=id) + + pub_broker._is_apply_types = True + pub_broker.include_router(router) + + await pub_broker.start() + + await pub_broker.publish( + "", + "test.in.john.2", + rpc=True, + ) + + assert event.is_set() + mock.assert_called_once_with(name="john", id=2) + + async def test_router_delay_handler_path( + self, + event, + mock, + router: NatsRouter, + pub_broker, + ): + async def h( + name: str = Path(), + id: int = Path("id"), + ): + event.set() + mock(name=name, id=id) + + r = type(router)(handlers=(self.route_class(h, subject="in.{name}.{id}"),)) + + pub_broker._is_apply_types = True + pub_broker.include_router(r) + + await pub_broker.start() + + await pub_broker.publish( + "", + "in.john.2", + rpc=True, + ) + + assert event.is_set() + mock.assert_called_once_with(name="john", id=2) + async def test_delayed_handlers_with_queue( self, event, diff --git a/tests/brokers/nats/test_rpc.py b/tests/brokers/nats/test_rpc.py index 7c0bd18f06..9675883c2b 100644 --- a/tests/brokers/nats/test_rpc.py +++ b/tests/brokers/nats/test_rpc.py @@ -6,8 +6,13 @@ @pytest.mark.nats() class TestRPC(BrokerRPCTestcase, ReplyAndConsumeForbidden): + def get_broker(self, apply_types: bool = False) -> NatsBroker: + return NatsBroker(apply_types=apply_types) + @pytest.mark.asyncio() - async def test_rpc_js(self, queue: str, rpc_broker: NatsBroker, stream: JStream): + async def test_rpc_js(self, queue: str, stream: JStream): + rpc_broker = self.get_broker() + @rpc_broker.subscriber(queue, stream=stream) async def m(m): # pragma: no cover return "1" diff --git a/tests/brokers/nats/test_test_client.py b/tests/brokers/nats/test_test_client.py index 8190e27509..ebbd1c7887 100644 --- a/tests/brokers/nats/test_test_client.py +++ b/tests/brokers/nats/test_test_client.py @@ -3,19 +3,68 @@ import pytest from faststream import BaseMiddleware +from faststream.exceptions import SetupError from faststream.nats import JStream, NatsBroker, PullSub, TestNatsBroker from tests.brokers.base.testclient import BrokerTestclientTestcase @pytest.mark.asyncio() class TestTestclient(BrokerTestclientTestcase): + test_class = TestNatsBroker + + def get_broker(self, apply_types: bool = False) -> NatsBroker: + return NatsBroker(apply_types=apply_types) + + def patch_broker(self, broker: NatsBroker) -> TestNatsBroker: + return TestNatsBroker(broker) + + @pytest.mark.asyncio() + async def test_stream_publish( + self, + queue: str, + ): + pub_broker = NatsBroker(apply_types=False) + + @pub_broker.subscriber(queue, stream="test") + async def m(msg): ... + + async with TestNatsBroker(pub_broker) as br: + await br.publish("Hi!", queue, stream="test") + m.mock.assert_called_once_with("Hi!") + + @pytest.mark.asyncio() + async def test_wrong_stream_publish( + self, + queue: str, + ): + pub_broker = NatsBroker(apply_types=False) + + @pub_broker.subscriber(queue) + async def m(msg): ... + + async with TestNatsBroker(pub_broker) as br: + await br.publish("Hi!", queue, stream="test") + assert not m.mock.called + + @pytest.mark.asyncio() + async def test_rpc_conflicts_reply(self, queue): + async with TestNatsBroker(NatsBroker()) as br: + with pytest.raises(SetupError): + await br.publish( + "", + queue, + rpc=True, + reply_to="response", + ) + @pytest.mark.nats() async def test_with_real_testclient( self, - broker: NatsBroker, queue: str, event: asyncio.Event, ): + broker = self.get_broker() + @broker.subscriber(queue) def subscriber(m): event.set() @@ -79,76 +128,92 @@ async def h2(): ... assert len(routes) == 2 async def test_js_subscriber_mock( - self, queue: str, test_broker: NatsBroker, stream: JStream + self, + queue: str, + stream: JStream, ): - @test_broker.subscriber(queue, stream=stream) - async def m(): + broker = self.get_broker() + + @broker.subscriber(queue, stream=stream) + async def m(msg): pass - await test_broker.start() - await test_broker.publish("hello", queue, stream=stream.name) - m.mock.assert_called_once_with("hello") + async with TestNatsBroker(broker) as br: + await br.publish("hello", queue, stream=stream.name) + m.mock.assert_called_once_with("hello") async def test_js_publisher_mock( - self, queue: str, test_broker: NatsBroker, stream: JStream + self, + queue: str, + stream: JStream, ): - publisher = test_broker.publisher(queue + "resp") + broker = self.get_broker() + + publisher = broker.publisher(queue + "resp") @publisher - @test_broker.subscriber(queue, stream=stream) - async def m(): + @broker.subscriber(queue, stream=stream) + async def m(msg): return "response" - await test_broker.start() - await test_broker.publish("hello", queue, stream=stream.name) - publisher.mock.assert_called_with("response") + async with TestNatsBroker(broker) as br: + await br.publish("hello", queue, stream=stream.name) + publisher.mock.assert_called_with("response") - async def test_any_subject_routing(self, test_broker: NatsBroker): - @test_broker.subscriber("test.*.subj.*") - def subscriber(): ... + async def test_any_subject_routing(self): + broker = self.get_broker() - await test_broker.start() - await test_broker.publish("hello", "test.a.subj.b") - subscriber.mock.assert_called_once_with("hello") + @broker.subscriber("test.*.subj.*") + def subscriber(msg): ... - async def test_ending_subject_routing(self, test_broker: NatsBroker): - @test_broker.subscriber("test.>") - def subscriber(): ... + async with TestNatsBroker(broker) as br: + await br.publish("hello", "test.a.subj.b") + subscriber.mock.assert_called_once_with("hello") - await test_broker.start() - await test_broker.publish("hello", "test.a.subj.b") - subscriber.mock.assert_called_once_with("hello") + async def test_ending_subject_routing(self): + broker = self.get_broker() - async def test_mixed_subject_routing(self, test_broker: NatsBroker): - @test_broker.subscriber("*.*.subj.>") - def subscriber(): ... + @broker.subscriber("test.>") + def subscriber(msg): ... - await test_broker.start() - await test_broker.publish("hello", "test.a.subj.b.c") - subscriber.mock.assert_called_once_with("hello") + async with TestNatsBroker(broker) as br: + await br.publish("hello", "test.a.subj.b") + subscriber.mock.assert_called_once_with("hello") + + async def test_mixed_subject_routing(self): + broker = self.get_broker() + + @broker.subscriber("*.*.subj.>") + def subscriber(msg): ... + + async with TestNatsBroker(broker) as br: + await br.publish("hello", "test.a.subj.b.c") + subscriber.mock.assert_called_once_with("hello") async def test_consume_pull( self, queue: str, - test_broker: NatsBroker, stream: JStream, ): - @test_broker.subscriber(queue, stream=stream, pull_sub=PullSub(1)) + broker = self.get_broker() + + @broker.subscriber(queue, stream=stream, pull_sub=PullSub(1)) def subscriber(m): ... - await test_broker.start() - await test_broker.publish("hello", queue) - subscriber.mock.assert_called_once_with("hello") + async with TestNatsBroker(broker) as br: + await br.publish("hello", queue) + subscriber.mock.assert_called_once_with("hello") async def test_consume_batch( self, queue: str, - test_broker: NatsBroker, stream: JStream, event: asyncio.Event, mock, ): - @test_broker.subscriber( + broker = self.get_broker() + + @broker.subscriber( queue, stream=stream, pull_sub=PullSub(1, batch=True), @@ -157,6 +222,6 @@ def subscriber(m): mock(m) event.set() - await test_broker.start() - await test_broker.publish("hello", queue) - subscriber.mock.assert_called_once_with(["hello"]) + async with TestNatsBroker(broker) as br: + await br.publish("hello", queue) + subscriber.mock.assert_called_once_with(["hello"]) diff --git a/tests/brokers/rabbit/__init__.py b/tests/brokers/rabbit/__init__.py index e69de29bb2..ebec43fcd5 100644 --- a/tests/brokers/rabbit/__init__.py +++ b/tests/brokers/rabbit/__init__.py @@ -0,0 +1,3 @@ +import pytest + +pytest.importorskip("aio_pika") diff --git a/tests/brokers/rabbit/test_consume.py b/tests/brokers/rabbit/test_consume.py index 56ef7b9bae..30b5cab321 100644 --- a/tests/brokers/rabbit/test_consume.py +++ b/tests/brokers/rabbit/test_consume.py @@ -13,24 +13,28 @@ @pytest.mark.rabbit() class TestConsume(BrokerRealConsumeTestcase): + def get_broker(self, apply_types: bool = False) -> RabbitBroker: + return RabbitBroker(apply_types=apply_types) + @pytest.mark.asyncio() async def test_consume_from_exchange( self, queue: str, exchange: RabbitExchange, - broker: RabbitBroker, event: asyncio.Event, ): - @broker.subscriber(queue=queue, exchange=exchange, retry=1) + consume_broker = self.get_broker() + + @consume_broker.subscriber(queue=queue, exchange=exchange, retry=1) def h(m): event.set() - async with broker: - await broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() await asyncio.wait( ( asyncio.create_task( - broker.publish("hello", queue=queue, exchange=exchange) + br.publish("hello", queue=queue, exchange=exchange) ), asyncio.create_task(event.wait()), ), @@ -44,13 +48,11 @@ async def test_consume_with_get_old( self, queue: str, exchange: RabbitExchange, - broker: RabbitBroker, event: asyncio.Event, ): - await broker.declare_queue(RabbitQueue(queue)) - await broker.declare_exchange(exchange) + consume_broker = self.get_broker() - @broker.subscriber( + @consume_broker.subscriber( queue=RabbitQueue(name=queue, passive=True), exchange=RabbitExchange(name=exchange.name, passive=True), retry=True, @@ -58,13 +60,19 @@ async def test_consume_with_get_old( def h(m): event.set() - async with broker: - await broker.start() + async with self.patch_broker(consume_broker) as br: + await br.declare_queue(RabbitQueue(queue)) + await br.declare_exchange(exchange) + + await br.start() + await asyncio.wait( ( asyncio.create_task( - broker.publish( - Message(b"hello"), queue=queue, exchange=exchange.name + br.publish( + Message(b"hello"), + queue=queue, + exchange=exchange.name, ) ), asyncio.create_task(event.wait()), @@ -79,22 +87,24 @@ async def test_consume_ack( self, queue: str, exchange: RabbitExchange, - full_broker: RabbitBroker, event: asyncio.Event, ): - @full_broker.subscriber(queue=queue, exchange=exchange, retry=1) + consume_broker = self.get_broker(apply_types=True) + + @consume_broker.subscriber(queue=queue, exchange=exchange, retry=1) async def handler(msg: RabbitMessage): event.set() - async with full_broker: - await full_broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() + with patch.object( IncomingMessage, "ack", spy_decorator(IncomingMessage.ack) ) as m: await asyncio.wait( ( asyncio.create_task( - full_broker.publish("hello", queue=queue, exchange=exchange) + br.publish("hello", queue=queue, exchange=exchange) ), asyncio.create_task(event.wait()), ), @@ -109,23 +119,25 @@ async def test_consume_manual_ack( self, queue: str, exchange: RabbitExchange, - full_broker: RabbitBroker, event: asyncio.Event, ): - @full_broker.subscriber(queue=queue, exchange=exchange, retry=1) + consume_broker = self.get_broker(apply_types=True) + + @consume_broker.subscriber(queue=queue, exchange=exchange, retry=1) async def handler(msg: RabbitMessage): await msg.ack() event.set() - async with full_broker: - await full_broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() + with patch.object( IncomingMessage, "ack", spy_decorator(IncomingMessage.ack) ) as m: await asyncio.wait( ( asyncio.create_task( - full_broker.publish("hello", queue=queue, exchange=exchange) + br.publish("hello", queue=queue, exchange=exchange) ), asyncio.create_task(event.wait()), ), @@ -139,25 +151,27 @@ async def test_consume_exception_ack( self, queue: str, exchange: RabbitExchange, - full_broker: RabbitBroker, event: asyncio.Event, ): - @full_broker.subscriber(queue=queue, exchange=exchange, retry=1) + consume_broker = self.get_broker(apply_types=True) + + @consume_broker.subscriber(queue=queue, exchange=exchange, retry=1) async def handler(msg: RabbitMessage): try: raise AckMessage() finally: event.set() - async with full_broker: - await full_broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() + with patch.object( IncomingMessage, "ack", spy_decorator(IncomingMessage.ack) ) as m: await asyncio.wait( ( asyncio.create_task( - full_broker.publish("hello", queue=queue, exchange=exchange) + br.publish("hello", queue=queue, exchange=exchange) ), asyncio.create_task(event.wait()), ), @@ -171,24 +185,26 @@ async def test_consume_manual_nack( self, queue: str, exchange: RabbitExchange, - full_broker: RabbitBroker, event: asyncio.Event, ): - @full_broker.subscriber(queue=queue, exchange=exchange, retry=1) + consume_broker = self.get_broker(apply_types=True) + + @consume_broker.subscriber(queue=queue, exchange=exchange, retry=1) async def handler(msg: RabbitMessage): await msg.nack() event.set() raise ValueError() - async with full_broker: - await full_broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() + with patch.object( IncomingMessage, "nack", spy_decorator(IncomingMessage.nack) ) as m: await asyncio.wait( ( asyncio.create_task( - full_broker.publish("hello", queue=queue, exchange=exchange) + br.publish("hello", queue=queue, exchange=exchange) ), asyncio.create_task(event.wait()), ), @@ -202,25 +218,27 @@ async def test_consume_exception_nack( self, queue: str, exchange: RabbitExchange, - full_broker: RabbitBroker, event: asyncio.Event, ): - @full_broker.subscriber(queue=queue, exchange=exchange, retry=1) + consume_broker = self.get_broker(apply_types=True) + + @consume_broker.subscriber(queue=queue, exchange=exchange, retry=1) async def handler(msg: RabbitMessage): try: raise NackMessage() finally: event.set() - async with full_broker: - await full_broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() + with patch.object( IncomingMessage, "nack", spy_decorator(IncomingMessage.nack) ) as m: await asyncio.wait( ( asyncio.create_task( - full_broker.publish("hello", queue=queue, exchange=exchange) + br.publish("hello", queue=queue, exchange=exchange) ), asyncio.create_task(event.wait()), ), @@ -234,24 +252,26 @@ async def test_consume_manual_reject( self, queue: str, exchange: RabbitExchange, - full_broker: RabbitBroker, event: asyncio.Event, ): - @full_broker.subscriber(queue=queue, exchange=exchange, retry=1) + consume_broker = self.get_broker(apply_types=True) + + @consume_broker.subscriber(queue=queue, exchange=exchange, retry=1) async def handler(msg: RabbitMessage): await msg.reject() event.set() raise ValueError() - async with full_broker: - await full_broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() + with patch.object( IncomingMessage, "reject", spy_decorator(IncomingMessage.reject) ) as m: await asyncio.wait( ( asyncio.create_task( - full_broker.publish("hello", queue=queue, exchange=exchange) + br.publish("hello", queue=queue, exchange=exchange) ), asyncio.create_task(event.wait()), ), @@ -265,25 +285,27 @@ async def test_consume_exception_reject( self, queue: str, exchange: RabbitExchange, - full_broker: RabbitBroker, event: asyncio.Event, ): - @full_broker.subscriber(queue=queue, exchange=exchange, retry=1) + consume_broker = self.get_broker(apply_types=True) + + @consume_broker.subscriber(queue=queue, exchange=exchange, retry=1) async def handler(msg: RabbitMessage): try: raise RejectMessage() finally: event.set() - async with full_broker: - await full_broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() + with patch.object( IncomingMessage, "reject", spy_decorator(IncomingMessage.reject) ) as m: await asyncio.wait( ( asyncio.create_task( - full_broker.publish("hello", queue=queue, exchange=exchange) + br.publish("hello", queue=queue, exchange=exchange) ), asyncio.create_task(event.wait()), ), @@ -296,18 +318,20 @@ async def handler(msg: RabbitMessage): async def test_consume_skip_message( self, queue: str, - full_broker: RabbitBroker, event: asyncio.Event, ): - @full_broker.subscriber(queue) + consume_broker = self.get_broker(apply_types=True) + + @consume_broker.subscriber(queue) async def handler(msg: RabbitMessage): try: raise SkipMessage() finally: event.set() - async with full_broker: - await full_broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() + with patch.object( IncomingMessage, "reject", spy_decorator(IncomingMessage.reject) ) as m, patch.object( @@ -317,7 +341,7 @@ async def handler(msg: RabbitMessage): ) as m2: await asyncio.wait( ( - asyncio.create_task(full_broker.publish("hello", queue)), + asyncio.create_task(br.publish("hello", queue)), asyncio.create_task(event.wait()), ), timeout=3, @@ -333,26 +357,29 @@ async def test_consume_no_ack( self, queue: str, exchange: RabbitExchange, - full_broker: RabbitBroker, event: asyncio.Event, ): - @full_broker.subscriber(queue, exchange=exchange, retry=1, no_ack=True) + consume_broker = self.get_broker(apply_types=True) + + @consume_broker.subscriber(queue, exchange=exchange, retry=1, no_ack=True) async def handler(msg: RabbitMessage): event.set() - await full_broker.start() - with patch.object( - IncomingMessage, "ack", spy_decorator(IncomingMessage.ack) - ) as m: - await asyncio.wait( - ( - asyncio.create_task( - full_broker.publish("hello", queue=queue, exchange=exchange) + async with self.patch_broker(consume_broker) as br: + await br.start() + + with patch.object( + IncomingMessage, "ack", spy_decorator(IncomingMessage.ack) + ) as m: + await asyncio.wait( + ( + asyncio.create_task( + br.publish("hello", queue=queue, exchange=exchange) + ), + asyncio.create_task(event.wait()), ), - asyncio.create_task(event.wait()), - ), - timeout=3, - ) - m.mock.assert_not_called() + timeout=3, + ) + m.mock.assert_not_called() - assert event.is_set() + assert event.is_set() diff --git a/tests/brokers/rabbit/test_fastapi.py b/tests/brokers/rabbit/test_fastapi.py index c79d0ff028..9248b2312f 100644 --- a/tests/brokers/rabbit/test_fastapi.py +++ b/tests/brokers/rabbit/test_fastapi.py @@ -1,3 +1,6 @@ +import asyncio +from unittest.mock import MagicMock + import pytest from faststream.rabbit import ExchangeType, RabbitExchange, RabbitQueue @@ -10,13 +13,51 @@ class TestRouter(FastAPITestcase): router_class = RabbitRouter + @pytest.mark.asyncio() + async def test_path( + self, + queue: str, + event: asyncio.Event, + mock: MagicMock, + ): + router = RabbitRouter() + + @router.subscriber( + RabbitQueue( + queue, + routing_key="in.{name}", + ), + RabbitExchange( + queue + "1", + type=ExchangeType.TOPIC, + ), + ) + def subscriber(msg: str, name: str): + mock(msg=msg, name=name) + event.set() + + async with router.broker: + await router.broker.start() + await asyncio.wait( + ( + asyncio.create_task( + router.broker.publish("hello", "in.john", queue + "1") + ), + asyncio.create_task(event.wait()), + ), + timeout=3, + ) + + assert event.is_set() + mock.assert_called_once_with(msg="hello", name="john") + +@pytest.mark.asyncio() class TestRouterLocal(FastAPILocalTestcase): router_class = RabbitRouter broker_test = staticmethod(TestRabbitBroker) build_message = staticmethod(build_message) - @pytest.mark.asyncio() async def test_path(self): router = self.router_class() diff --git a/tests/brokers/rabbit/test_publish.py b/tests/brokers/rabbit/test_publish.py index 0cf9672866..97be60f066 100644 --- a/tests/brokers/rabbit/test_publish.py +++ b/tests/brokers/rabbit/test_publish.py @@ -11,15 +11,21 @@ @pytest.mark.rabbit() class TestPublish(BrokerPublishTestcase): + def get_broker(self, apply_types: bool = False) -> RabbitBroker: + return RabbitBroker(apply_types=apply_types) + @pytest.mark.asyncio() async def test_reply_config( self, - pub_broker: RabbitBroker, queue: str, event, mock, ): - @pub_broker.subscriber(queue + "reply") + pub_broker = self.get_broker() + + reply_queue = queue + "reply" + + @pub_broker.subscriber(reply_queue) async def reply_handler(m): event.set() mock(m) @@ -28,20 +34,18 @@ async def reply_handler(m): async def handler(m): return m - async with pub_broker: + async with self.patch_broker(pub_broker) as br: with patch.object( AioPikaFastProducer, "publish", spy_decorator(AioPikaFastProducer.publish), ) as m: - await pub_broker.start() + await br.start() await asyncio.wait( ( asyncio.create_task( - pub_broker.publish( - "Hello!", queue, reply_to=queue + "reply" - ) + br.publish("Hello!", queue, reply_to=reply_queue) ), asyncio.create_task(event.wait()), ), diff --git a/tests/brokers/rabbit/test_router.py b/tests/brokers/rabbit/test_router.py index 5af8c581e4..ac14d3372d 100644 --- a/tests/brokers/rabbit/test_router.py +++ b/tests/brokers/rabbit/test_router.py @@ -2,8 +2,11 @@ import pytest +from faststream import Path from faststream.rabbit import ( + ExchangeType, RabbitBroker, + RabbitExchange, RabbitPublisher, RabbitQueue, RabbitRoute, @@ -18,6 +21,92 @@ class TestRouter(RouterTestcase): route_class = RabbitRoute publisher_class = RabbitPublisher + async def test_router_path( + self, + queue, + event, + mock, + router, + pub_broker, + ): + @router.subscriber( + RabbitQueue( + queue, + routing_key="in.{name}.{id}", + ), + RabbitExchange( + queue + "1", + type=ExchangeType.TOPIC, + ), + ) + async def h( + name: str = Path(), + id: int = Path("id"), + ): + event.set() + mock(name=name, id=id) + + pub_broker._is_apply_types = True + pub_broker.include_router(router) + + await pub_broker.start() + + await pub_broker.publish( + "", + "in.john.2", + queue + "1", + rpc=True, + ) + + assert event.is_set() + mock.assert_called_once_with(name="john", id=2) + + async def test_router_delay_handler_path( + self, + queue, + event, + mock, + router, + pub_broker, + ): + async def h( + name: str = Path(), + id: int = Path("id"), + ): + event.set() + mock(name=name, id=id) + + r = type(router)( + handlers=( + self.route_class( + h, + queue=RabbitQueue( + queue, + routing_key="in.{name}.{id}", + ), + exchange=RabbitExchange( + queue + "1", + type=ExchangeType.TOPIC, + ), + ), + ) + ) + + pub_broker._is_apply_types = True + pub_broker.include_router(r) + + await pub_broker.start() + + await pub_broker.publish( + "", + "in.john.2", + queue + "1", + rpc=True, + ) + + assert event.is_set() + mock.assert_called_once_with(name="john", id=2) + async def test_queue_obj( self, router: RabbitRouter, @@ -50,6 +139,39 @@ def subscriber(m): assert event.is_set() + async def test_queue_obj_with_routing_key( + self, + router: RabbitRouter, + broker: RabbitBroker, + queue: str, + event: asyncio.Event, + ): + router.prefix = "test/" + + r_queue = RabbitQueue("useless", routing_key=f"{queue}1") + exchange = RabbitExchange(f"{queue}exch") + + @router.subscriber(r_queue, exchange=exchange) + def subscriber(m): + event.set() + + broker.include_router(router) + + async with broker: + await broker.start() + + await asyncio.wait( + ( + asyncio.create_task( + broker.publish("hello", f"test/{queue}1", exchange=exchange) + ), + asyncio.create_task(event.wait()), + ), + timeout=3, + ) + + assert event.is_set() + async def test_delayed_handlers_with_queue( self, event: asyncio.Event, diff --git a/tests/brokers/rabbit/test_rpc.py b/tests/brokers/rabbit/test_rpc.py index 76b8fc2b68..d0bd80cab7 100644 --- a/tests/brokers/rabbit/test_rpc.py +++ b/tests/brokers/rabbit/test_rpc.py @@ -1,8 +1,10 @@ import pytest +from faststream.rabbit import RabbitBroker from tests.brokers.base.rpc import BrokerRPCTestcase, ReplyAndConsumeForbidden @pytest.mark.rabbit() class TestRPC(BrokerRPCTestcase, ReplyAndConsumeForbidden): - pass + def get_broker(self, apply_types: bool = False) -> RabbitBroker: + return RabbitBroker(apply_types=apply_types) diff --git a/tests/brokers/rabbit/test_test_client.py b/tests/brokers/rabbit/test_test_client.py index b5f32f0de6..e07cbd88c0 100644 --- a/tests/brokers/rabbit/test_test_client.py +++ b/tests/brokers/rabbit/test_test_client.py @@ -1,9 +1,9 @@ import asyncio -from unittest.mock import Mock import pytest from faststream import BaseMiddleware +from faststream.exceptions import SetupError from faststream.rabbit import ( ExchangeType, RabbitBroker, @@ -18,13 +18,34 @@ @pytest.mark.asyncio() class TestTestclient(BrokerTestclientTestcase): + test_class = TestRabbitBroker + + def get_broker(self, apply_types: bool = False) -> RabbitBroker: + return RabbitBroker(apply_types=apply_types) + + def patch_broker(self, broker: RabbitBroker) -> RabbitBroker: + return TestRabbitBroker(broker) + + async def test_rpc_conflicts_reply(self, queue): + broker = self.get_broker() + + async with TestRabbitBroker(broker) as br: + with pytest.raises(SetupError): + await br.publish( + "", + queue, + rpc=True, + reply_to="response", + ) + @pytest.mark.rabbit() async def test_with_real_testclient( self, - broker: RabbitBroker, queue: str, event: asyncio.Event, ): + broker = self.get_broker() + @broker.subscriber(queue) def subscriber(m): event.set() @@ -42,86 +63,92 @@ def subscriber(m): async def test_direct( self, - test_broker: RabbitBroker, queue: str, ): - @test_broker.subscriber(queue) + broker = self.get_broker() + + @broker.subscriber(queue) async def handler(m): return 1 - @test_broker.subscriber(queue + "1", exchange="test") + @broker.subscriber(queue + "1", exchange="test") async def handler2(m): return 2 - await test_broker.start() - assert await test_broker.publish("", queue, rpc=True) == 1 - assert ( - await test_broker.publish("", queue + "1", exchange="test", rpc=True) == 2 - ) - assert None is await test_broker.publish("", exchange="test2", rpc=True) + async with TestRabbitBroker(broker) as br: + await br.start() + assert await br.publish("", queue, rpc=True) == 1 + assert await br.publish("", queue + "1", exchange="test", rpc=True) == 2 + assert None is await br.publish("", exchange="test2", rpc=True) async def test_fanout( self, - test_broker: RabbitBroker, queue: str, + mock, ): - mock = Mock() + broker = self.get_broker() exch = RabbitExchange("test", type=ExchangeType.FANOUT) - @test_broker.subscriber(queue, exchange=exch) + @broker.subscriber(queue, exchange=exch) async def handler(m): mock() - await test_broker.start() - await test_broker.publish("", exchange=exch, rpc=True) - assert None is await test_broker.publish("", exchange="test2", rpc=True) + async with TestRabbitBroker(broker) as br: + await br.publish("", exchange=exch, rpc=True) + + assert None is await br.publish("", exchange="test2", rpc=True) + + assert mock.call_count == 1 - assert mock.call_count == 1 + async def test_any_topic_routing(self): + broker = self.get_broker() - async def test_any_topic_routing(self, test_broker: RabbitBroker): exch = RabbitExchange("test", type=ExchangeType.TOPIC) - @test_broker.subscriber( + @broker.subscriber( RabbitQueue("test", routing_key="test.*.subj.*"), exchange=exch, ) - def subscriber(): ... + def subscriber(msg): ... - await test_broker.start() - await test_broker.publish("hello", "test.a.subj.b", exchange=exch) - subscriber.mock.assert_called_once_with("hello") + async with TestRabbitBroker(broker) as br: + await br.publish("hello", "test.a.subj.b", exchange=exch) + subscriber.mock.assert_called_once_with("hello") + + async def test_ending_topic_routing(self): + broker = self.get_broker() - async def test_ending_topic_routing(self, test_broker: RabbitBroker): exch = RabbitExchange("test", type=ExchangeType.TOPIC) - @test_broker.subscriber( + @broker.subscriber( RabbitQueue("test", routing_key="test.#"), exchange=exch, ) - def subscriber(): ... + def subscriber(msg): ... + + async with TestRabbitBroker(broker) as br: + await br.publish("hello", "test.a.subj.b", exchange=exch) + subscriber.mock.assert_called_once_with("hello") - await test_broker.start() - await test_broker.publish("hello", "test.a.subj.b", exchange=exch) - subscriber.mock.assert_called_once_with("hello") + async def test_mixed_topic_routing(self): + broker = self.get_broker() - async def test_mixed_topic_routing(self, test_broker: RabbitBroker): exch = RabbitExchange("test", type=ExchangeType.TOPIC) - @test_broker.subscriber( + @broker.subscriber( RabbitQueue("test", routing_key="*.*.subj.#"), exchange=exch, ) - def subscriber(): ... + def subscriber(msg): ... - await test_broker.start() - await test_broker.publish("hello", "test.a.subj.b.c", exchange=exch) - subscriber.mock.assert_called_once_with("hello") + async with TestRabbitBroker(broker) as br: + await br.publish("hello", "test.a.subj.b.c", exchange=exch) + subscriber.mock.assert_called_once_with("hello") + + async def test_header(self): + broker = self.get_broker() - async def test_header( - self, - test_broker: RabbitBroker, - ): q1 = RabbitQueue( "test-queue-2", bind_arguments={"key": 2, "key2": 2, "x-match": "any"}, @@ -136,74 +163,65 @@ async def test_header( ) exch = RabbitExchange("exchange", type=ExchangeType.HEADERS) - @test_broker.subscriber(q2, exch) - async def handler2(): + @broker.subscriber(q2, exch) + async def handler2(msg): return 2 - @test_broker.subscriber(q1, exch) - async def handler(): + @broker.subscriber(q1, exch) + async def handler(msg): return 1 - @test_broker.subscriber(q3, exch) - async def handler3(): + @broker.subscriber(q3, exch) + async def handler3(msg): return 3 - await test_broker.start() - assert ( - await test_broker.publish( - exchange=exch, rpc=True, headers={"key": 2, "key2": 2} + async with TestRabbitBroker(broker) as br: + assert ( + await br.publish(exchange=exch, rpc=True, headers={"key": 2, "key2": 2}) + == 2 ) - == 2 - ) - assert ( - await test_broker.publish(exchange=exch, rpc=True, headers={"key": 2}) == 1 - ) - assert await test_broker.publish(exchange=exch, rpc=True, headers={}) == 3 + assert await br.publish(exchange=exch, rpc=True, headers={"key": 2}) == 1 + assert await br.publish(exchange=exch, rpc=True, headers={}) == 3 async def test_consume_manual_ack( self, queue: str, exchange: RabbitExchange, - test_broker: RabbitBroker, ): + broker = self.get_broker(apply_types=True) + consume = asyncio.Event() consume2 = asyncio.Event() consume3 = asyncio.Event() - @test_broker.subscriber(queue=queue, exchange=exchange, retry=1) + @broker.subscriber(queue=queue, exchange=exchange, retry=1) async def handler(msg: RabbitMessage): await msg.raw_message.ack() consume.set() - @test_broker.subscriber(queue=queue + "1", exchange=exchange, retry=1) + @broker.subscriber(queue=queue + "1", exchange=exchange, retry=1) async def handler2(msg: RabbitMessage): await msg.raw_message.nack() consume2.set() raise ValueError() - @test_broker.subscriber(queue=queue + "2", exchange=exchange, retry=1) + @broker.subscriber(queue=queue + "2", exchange=exchange, retry=1) async def handler3(msg: RabbitMessage): await msg.raw_message.reject() consume3.set() raise ValueError() - await test_broker.start() - async with test_broker: - await test_broker.start() + async with TestRabbitBroker(broker) as br: await asyncio.wait( ( asyncio.create_task( - test_broker.publish("hello", queue=queue, exchange=exchange) + br.publish("hello", queue=queue, exchange=exchange) ), asyncio.create_task( - test_broker.publish( - "hello", queue=queue + "1", exchange=exchange - ) + br.publish("hello", queue=queue + "1", exchange=exchange) ), asyncio.create_task( - test_broker.publish( - "hello", queue=queue + "2", exchange=exchange - ) + br.publish("hello", queue=queue + "2", exchange=exchange) ), asyncio.create_task(consume.wait()), asyncio.create_task(consume2.wait()), @@ -227,10 +245,10 @@ async def on_receive(self) -> None: broker = RabbitBroker(middlewares=(Middleware,)) @broker.subscriber(queue) - async def h1(): ... + async def h1(msg): ... @broker.subscriber(queue + "1") - async def h2(): ... + async def h2(msg): ... async with TestRabbitBroker(broker) as br: await br.publish("", queue) @@ -250,10 +268,10 @@ async def on_receive(self) -> None: broker = RabbitBroker(middlewares=(Middleware,)) @broker.subscriber(queue) - async def h1(): ... + async def h1(msg): ... @broker.subscriber(queue + "1") - async def h2(): ... + async def h2(msg): ... async with TestRabbitBroker(broker, with_real=True) as br: await br.publish("", queue) diff --git a/tests/brokers/redis/__init__.py b/tests/brokers/redis/__init__.py index e69de29bb2..4752ef19b1 100644 --- a/tests/brokers/redis/__init__.py +++ b/tests/brokers/redis/__init__.py @@ -0,0 +1,3 @@ +import pytest + +pytest.importorskip("redis") diff --git a/tests/brokers/redis/test_consume.py b/tests/brokers/redis/test_consume.py index 176fd0965f..071467c449 100644 --- a/tests/brokers/redis/test_consume.py +++ b/tests/brokers/redis/test_consume.py @@ -13,25 +13,28 @@ @pytest.mark.redis() @pytest.mark.asyncio() class TestConsume(BrokerRealConsumeTestcase): + def get_broker(self, apply_types: bool = False): + return RedisBroker(apply_types=apply_types) + async def test_consume_native( self, - consume_broker: RedisBroker, event: asyncio.Event, mock: MagicMock, queue: str, ): + consume_broker = self.get_broker() + @consume_broker.subscriber(queue) async def handler(msg): mock(msg) event.set() - async with consume_broker: - await consume_broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() + await asyncio.wait( ( - asyncio.create_task( - consume_broker._connection.publish(queue, "hello") - ), + asyncio.create_task(br._connection.publish(queue, "hello")), asyncio.create_task(event.wait()), ), timeout=3, @@ -41,20 +44,22 @@ async def handler(msg): async def test_pattern_with_path( self, - consume_broker: RedisBroker, event: asyncio.Event, mock: MagicMock, ): + consume_broker = self.get_broker() + @consume_broker.subscriber("test.{name}") async def handler(msg): mock(msg) event.set() - async with consume_broker: - await consume_broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() + await asyncio.wait( ( - asyncio.create_task(consume_broker.publish("hello", "test.name")), + asyncio.create_task(br.publish("hello", "test.name")), asyncio.create_task(event.wait()), ), timeout=3, @@ -64,20 +69,22 @@ async def handler(msg): async def test_pattern_without_path( self, - consume_broker: RedisBroker, event: asyncio.Event, mock: MagicMock, ): + consume_broker = self.get_broker() + @consume_broker.subscriber(PubSub("test.*", pattern=True)) async def handler(msg): mock(msg) event.set() - async with consume_broker: - await consume_broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() + await asyncio.wait( ( - asyncio.create_task(consume_broker.publish("hello", "test.name")), + asyncio.create_task(br.publish("hello", "test.name")), asyncio.create_task(event.wait()), ), timeout=3, @@ -89,23 +96,31 @@ async def handler(msg): @pytest.mark.redis() @pytest.mark.asyncio() class TestConsumeList: + def get_broker(self, apply_types: bool = False): + return RedisBroker(apply_types=apply_types) + + def patch_broker(self, broker): + return broker + async def test_consume_list( self, - broker: RedisBroker, event: asyncio.Event, queue: str, mock: MagicMock, ): - @broker.subscriber(list=queue) + consume_broker = self.get_broker() + + @consume_broker.subscriber(list=queue) async def handler(msg): mock(msg) event.set() - async with broker: - await broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() + await asyncio.wait( ( - asyncio.create_task(broker.publish("hello", list=queue)), + asyncio.create_task(br.publish("hello", list=queue)), asyncio.create_task(event.wait()), ), timeout=3, @@ -115,21 +130,23 @@ async def handler(msg): async def test_consume_list_native( self, - broker: RedisBroker, event: asyncio.Event, queue: str, mock: MagicMock, ): - @broker.subscriber(list=queue) + consume_broker = self.get_broker() + + @consume_broker.subscriber(list=queue) async def handler(msg): mock(msg) event.set() - async with broker: - await broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() + await asyncio.wait( ( - asyncio.create_task(broker._connection.rpush(queue, "hello")), + asyncio.create_task(br._connection.rpush(queue, "hello")), asyncio.create_task(event.wait()), ), timeout=3, @@ -138,47 +155,107 @@ async def handler(msg): mock.assert_called_once_with(b"hello") @pytest.mark.slow() - async def test_consume_list_batch_with_one(self, queue: str, broker: RedisBroker): - msgs_queue = asyncio.Queue(maxsize=1) + async def test_consume_list_batch_with_one( + self, + queue: str, + event: asyncio.Event, + mock, + ): + consume_broker = self.get_broker() - @broker.subscriber(list=ListSub(queue, batch=True, polling_interval=1)) + @consume_broker.subscriber( + list=ListSub(queue, batch=True, polling_interval=0.01) + ) async def handler(msg): - await msgs_queue.put(msg) + mock(msg) + event.set() - async with broker: - await broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() + await asyncio.wait( + ( + asyncio.create_task(br.publish("hi", list=queue)), + asyncio.create_task(event.wait()), + ), + timeout=3, + ) - await broker.publish("hi", list=queue) + assert event.is_set() + mock.assert_called_once_with(["hi"]) - result, _ = await asyncio.wait( - (asyncio.create_task(msgs_queue.get()),), + @pytest.mark.slow() + async def test_consume_list_batch_headers( + self, + queue: str, + event: asyncio.Event, + mock, + ): + consume_broker = self.get_broker(apply_types=True) + + @consume_broker.subscriber( + list=ListSub(queue, batch=True, polling_interval=0.01) + ) + def subscriber(m, msg: RedisMessage): + check = all( + ( + msg.headers, + msg.headers["correlation_id"] + == msg.batch_headers[0]["correlation_id"], + msg.headers.get("custom") == "1", + ) + ) + mock(check) + event.set() + + async with self.patch_broker(consume_broker) as br: + await br.start() + await asyncio.wait( + ( + asyncio.create_task( + br.publish("", list=queue, headers={"custom": "1"}) + ), + asyncio.create_task(event.wait()), + ), timeout=3, ) - assert ["hi"] == [r.result()[0] for r in result] + assert event.is_set() + mock.assert_called_once_with(True) @pytest.mark.slow() - async def test_consume_list_batch(self, queue: str, broker: RedisBroker): + async def test_consume_list_batch( + self, + queue: str, + ): + consume_broker = self.get_broker(apply_types=True) + msgs_queue = asyncio.Queue(maxsize=1) - @broker.subscriber(list=ListSub(queue, batch=True, polling_interval=1)) + @consume_broker.subscriber( + list=ListSub(queue, batch=True, polling_interval=0.01) + ) async def handler(msg): await msgs_queue.put(msg) - async with broker: - await broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() - await broker.publish_batch(1, "hi", list=queue) + await br.publish_batch(1, "hi", list=queue) result, _ = await asyncio.wait( (asyncio.create_task(msgs_queue.get()),), timeout=3, ) - assert [{1, "hi"}] == [set(r.result()) for r in result] + assert [{1, "hi"}] == [set(r.result()) for r in result] @pytest.mark.slow() - async def test_consume_list_batch_complex(self, queue: str, broker: RedisBroker): + async def test_consume_list_batch_complex( + self, + queue: str, + ): + consume_broker = self.get_broker(apply_types=True) + from pydantic import BaseModel class Data(BaseModel): @@ -189,15 +266,16 @@ def __hash__(self): msgs_queue = asyncio.Queue(maxsize=1) - @broker.subscriber(list=ListSub(queue, batch=True, polling_interval=1)) + @consume_broker.subscriber( + list=ListSub(queue, batch=True, polling_interval=0.01) + ) async def handler(msg: List[Data]): await msgs_queue.put(msg) - broker._is_apply_types = True - async with broker: - await broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() - await broker.publish_batch(Data(m="hi"), Data(m="again"), list=queue) + await br.publish_batch(Data(m="hi"), Data(m="again"), list=queue) result, _ = await asyncio.wait( (asyncio.create_task(msgs_queue.get()),), @@ -207,17 +285,24 @@ async def handler(msg: List[Data]): assert [{Data(m="hi"), Data(m="again")}] == [set(r.result()) for r in result] @pytest.mark.slow() - async def test_consume_list_batch_native(self, queue: str, broker: RedisBroker): + async def test_consume_list_batch_native( + self, + queue: str, + ): + consume_broker = self.get_broker() + msgs_queue = asyncio.Queue(maxsize=1) - @broker.subscriber(list=ListSub(queue, batch=True, polling_interval=1)) + @consume_broker.subscriber( + list=ListSub(queue, batch=True, polling_interval=0.01) + ) async def handler(msg): await msgs_queue.put(msg) - async with broker: - await broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() - await broker._connection.rpush(queue, 1, "hi") + await br._connection.rpush(queue, 1, "hi") result, _ = await asyncio.wait( (asyncio.create_task(msgs_queue.get()),), @@ -230,25 +315,32 @@ async def handler(msg): @pytest.mark.redis() @pytest.mark.asyncio() class TestConsumeStream: + def get_broker(self, apply_types: bool = False): + return RedisBroker(apply_types=apply_types) + + def patch_broker(self, broker): + return broker + @pytest.mark.slow() async def test_consume_stream( self, - broker: RedisBroker, event: asyncio.Event, mock: MagicMock, queue, ): - @broker.subscriber(stream=StreamSub(queue, polling_interval=3000)) + consume_broker = self.get_broker() + + @consume_broker.subscriber(stream=StreamSub(queue, polling_interval=10)) async def handler(msg): mock(msg) event.set() - async with broker: - await broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() await asyncio.wait( ( - asyncio.create_task(broker.publish("hello", stream=queue)), + asyncio.create_task(br.publish("hello", stream=queue)), asyncio.create_task(event.wait()), ), timeout=3, @@ -259,23 +351,24 @@ async def handler(msg): @pytest.mark.slow() async def test_consume_stream_native( self, - broker: RedisBroker, event: asyncio.Event, mock: MagicMock, queue, ): - @broker.subscriber(stream=StreamSub(queue, polling_interval=3000)) + consume_broker = self.get_broker() + + @consume_broker.subscriber(stream=StreamSub(queue, polling_interval=10)) async def handler(msg): mock(msg) event.set() - async with broker: - await broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() await asyncio.wait( ( asyncio.create_task( - broker._connection.xadd(queue, {"message": "hello"}) + br._connection.xadd(queue, {"message": "hello"}) ), asyncio.create_task(event.wait()), ), @@ -287,22 +380,25 @@ async def handler(msg): @pytest.mark.slow() async def test_consume_stream_batch( self, - broker: RedisBroker, event: asyncio.Event, mock: MagicMock, queue, ): - @broker.subscriber(stream=StreamSub(queue, polling_interval=3000, batch=True)) + consume_broker = self.get_broker() + + @consume_broker.subscriber( + stream=StreamSub(queue, polling_interval=10, batch=True) + ) async def handler(msg): mock(msg) event.set() - async with broker: - await broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() await asyncio.wait( ( - asyncio.create_task(broker.publish("hello", stream=queue)), + asyncio.create_task(br.publish("hello", stream=queue)), asyncio.create_task(event.wait()), ), timeout=3, @@ -310,12 +406,52 @@ async def handler(msg): mock.assert_called_once_with(["hello"]) + @pytest.mark.slow() + async def test_consume_stream_batch_headers( + self, + queue: str, + event: asyncio.Event, + mock, + ): + consume_broker = self.get_broker(apply_types=True) + + @consume_broker.subscriber( + stream=StreamSub(queue, polling_interval=10, batch=True) + ) + def subscriber(m, msg: RedisMessage): + check = all( + ( + msg.headers, + msg.headers["correlation_id"] + == msg.batch_headers[0]["correlation_id"], + msg.headers.get("custom") == "1", + ) + ) + mock(check) + event.set() + + async with self.patch_broker(consume_broker) as br: + await br.start() + await asyncio.wait( + ( + asyncio.create_task( + br.publish("", stream=queue, headers={"custom": "1"}) + ), + asyncio.create_task(event.wait()), + ), + timeout=3, + ) + + assert event.is_set() + mock.assert_called_once_with(True) + @pytest.mark.slow() async def test_consume_stream_batch_complex( self, - broker: RedisBroker, queue, ): + consume_broker = self.get_broker(apply_types=True) + from pydantic import BaseModel class Data(BaseModel): @@ -323,15 +459,16 @@ class Data(BaseModel): msgs_queue = asyncio.Queue(maxsize=1) - @broker.subscriber(stream=StreamSub(queue, polling_interval=3000, batch=True)) + @consume_broker.subscriber( + stream=StreamSub(queue, polling_interval=10, batch=True) + ) async def handler(msg: List[Data]): await msgs_queue.put(msg) - broker._is_apply_types = True - async with broker: - await broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() - await broker.publish(Data(m="hi"), stream=queue) + await br.publish(Data(m="hi"), stream=queue) result, _ = await asyncio.wait( (asyncio.create_task(msgs_queue.get()),), @@ -343,23 +480,26 @@ async def handler(msg: List[Data]): @pytest.mark.slow() async def test_consume_stream_batch_native( self, - broker: RedisBroker, event: asyncio.Event, mock: MagicMock, queue, ): - @broker.subscriber(stream=StreamSub(queue, polling_interval=3000, batch=True)) + consume_broker = self.get_broker() + + @consume_broker.subscriber( + stream=StreamSub(queue, polling_interval=10, batch=True) + ) async def handler(msg): mock(msg) event.set() - async with broker: - await broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() await asyncio.wait( ( asyncio.create_task( - broker._connection.xadd(queue, {"message": "hello"}) + br._connection.xadd(queue, {"message": "hello"}) ), asyncio.create_task(event.wait()), ), @@ -371,43 +511,50 @@ async def handler(msg): async def test_consume_group( self, queue: str, - full_broker: RedisBroker, ): - @full_broker.subscriber(stream=StreamSub(queue, group="group", consumer=queue)) + consume_broker = self.get_broker() + + @consume_broker.subscriber( + stream=StreamSub(queue, group="group", consumer=queue) + ) async def handler(msg: RedisMessage): ... - assert next(iter(full_broker._subscribers.values())).last_id == "$" + assert next(iter(consume_broker._subscribers.values())).last_id == "$" async def test_consume_group_with_last_id( self, queue: str, - full_broker: RedisBroker, ): - @full_broker.subscriber( + consume_broker = self.get_broker() + + @consume_broker.subscriber( stream=StreamSub(queue, group="group", consumer=queue, last_id="0") ) async def handler(msg: RedisMessage): ... - assert next(iter(full_broker._subscribers.values())).last_id == "0" + assert next(iter(consume_broker._subscribers.values())).last_id == "0" async def test_consume_nack( self, queue: str, - full_broker: RedisBroker, event: asyncio.Event, ): - @full_broker.subscriber(stream=StreamSub(queue, group="group", consumer=queue)) + consume_broker = self.get_broker(apply_types=True) + + @consume_broker.subscriber( + stream=StreamSub(queue, group="group", consumer=queue) + ) async def handler(msg: RedisMessage): event.set() await msg.nack() - async with full_broker: - await full_broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() with patch.object(Redis, "xack", spy_decorator(Redis.xack)) as m: await asyncio.wait( ( - asyncio.create_task(full_broker.publish("hello", stream=queue)), + asyncio.create_task(br.publish("hello", stream=queue)), asyncio.create_task(event.wait()), ), timeout=3, @@ -420,20 +567,23 @@ async def handler(msg: RedisMessage): async def test_consume_ack( self, queue: str, - full_broker: RedisBroker, event: asyncio.Event, ): - @full_broker.subscriber(stream=StreamSub(queue, group="group", consumer=queue)) + consume_broker = self.get_broker(apply_types=True) + + @consume_broker.subscriber( + stream=StreamSub(queue, group="group", consumer=queue) + ) async def handler(msg: RedisMessage): event.set() - async with full_broker: - await full_broker.start() + async with self.patch_broker(consume_broker) as br: + await br.start() with patch.object(Redis, "xack", spy_decorator(Redis.xack)) as m: await asyncio.wait( ( - asyncio.create_task(full_broker.publish("hello", stream=queue)), + asyncio.create_task(br.publish("hello", stream=queue)), asyncio.create_task(event.wait()), ), timeout=3, diff --git a/tests/brokers/redis/test_fastapi.py b/tests/brokers/redis/test_fastapi.py index 4b8fcf5f96..36e95d1a29 100644 --- a/tests/brokers/redis/test_fastapi.py +++ b/tests/brokers/redis/test_fastapi.py @@ -14,6 +14,32 @@ class TestRouter(FastAPITestcase): router_class = RedisRouter + async def test_path( + self, + queue: str, + event: asyncio.Event, + mock: Mock, + ): + router = RedisRouter() + + @router.subscriber("in.{name}") + def subscriber(msg: str, name: str): + mock(msg=msg, name=name) + event.set() + + async with router.broker: + await router.broker.start() + await asyncio.wait( + ( + asyncio.create_task(router.broker.publish("hello", "in.john")), + asyncio.create_task(event.wait()), + ), + timeout=3, + ) + + assert event.is_set() + mock.assert_called_once_with(msg="hello", name="john") + async def test_connection_params(self, settings): broker = RedisRouter( host="fake-host", port=6377 @@ -60,7 +86,7 @@ async def test_consume_stream( ): router = RedisRouter() - @router.subscriber(stream=StreamSub(queue, polling_interval=3000)) + @router.subscriber(stream=StreamSub(queue, polling_interval=10)) async def handler(msg): mock(msg) event.set() @@ -88,7 +114,7 @@ async def test_consume_stream_batch( ): router = RedisRouter() - @router.subscriber(stream=StreamSub(queue, polling_interval=3000, batch=True)) + @router.subscriber(stream=StreamSub(queue, polling_interval=10, batch=True)) async def handler(msg: List[str]): mock(msg) event.set() diff --git a/tests/brokers/redis/test_publish.py b/tests/brokers/redis/test_publish.py index 6210c3ddde..2c1f2b96ff 100644 --- a/tests/brokers/redis/test_publish.py +++ b/tests/brokers/redis/test_publish.py @@ -12,16 +12,20 @@ @pytest.mark.redis() @pytest.mark.asyncio() class TestPublish(BrokerPublishTestcase): + def get_broker(self, apply_types: bool = False): + return RedisBroker(apply_types=apply_types) + async def test_list_publisher( self, queue: str, - pub_broker: RedisBroker, event: asyncio.Event, mock: MagicMock, ): + pub_broker = self.get_broker() + @pub_broker.subscriber(list=queue) @pub_broker.publisher(list=queue + "resp") - async def m(): + async def m(msg): return "" @pub_broker.subscriber(list=queue + "resp") @@ -29,11 +33,12 @@ async def resp(msg): event.set() mock(msg) - async with pub_broker: - await pub_broker.start() + async with self.patch_broker(pub_broker) as br: + await br.start() + await asyncio.wait( ( - asyncio.create_task(pub_broker.publish("", list=queue)), + asyncio.create_task(br.publish("", list=queue)), asyncio.create_task(event.wait()), ), timeout=3, @@ -42,17 +47,22 @@ async def resp(msg): assert event.is_set() mock.assert_called_once_with("") - async def test_list_publish_batch(self, queue: str, broker: RedisBroker): + async def test_list_publish_batch( + self, + queue: str, + ): + pub_broker = self.get_broker() + msgs_queue = asyncio.Queue(maxsize=2) - @broker.subscriber(list=queue) + @pub_broker.subscriber(list=queue) async def handler(msg): await msgs_queue.put(msg) - async with broker: - await broker.start() + async with self.patch_broker(pub_broker) as br: + await br.start() - await broker.publish_batch(1, "hi", list=queue) + await br.publish_batch(1, "hi", list=queue) result, _ = await asyncio.wait( ( @@ -67,15 +77,16 @@ async def handler(msg): async def test_batch_list_publisher( self, queue: str, - pub_broker: RedisBroker, event: asyncio.Event, mock: MagicMock, ): + pub_broker = self.get_broker() + batch_list = ListSub(queue + "resp", batch=True) @pub_broker.subscriber(list=queue) @pub_broker.publisher(list=batch_list) - async def m(): + async def m(msg): return 1, 2, 3 @pub_broker.subscriber(list=batch_list) @@ -83,11 +94,12 @@ async def resp(msg): event.set() mock(msg) - async with pub_broker: - await pub_broker.start() + async with self.patch_broker(pub_broker) as br: + await br.start() + await asyncio.wait( ( - asyncio.create_task(pub_broker.publish("", list=queue)), + asyncio.create_task(br.publish("", list=queue)), asyncio.create_task(event.wait()), ), timeout=3, @@ -99,10 +111,11 @@ async def resp(msg): async def test_publisher_with_maxlen( self, queue: str, - pub_broker: RedisBroker, event: asyncio.Event, mock: MagicMock, ): + pub_broker = self.get_broker() + stream = StreamSub(queue + "resp", maxlen=1) @pub_broker.subscriber(stream=queue) @@ -116,11 +129,12 @@ async def resp(msg): mock(msg) with patch.object(Redis, "xadd", spy_decorator(Redis.xadd)) as m: - async with pub_broker: - await pub_broker.start() + async with self.patch_broker(pub_broker) as br: + await br.start() + await asyncio.wait( ( - asyncio.create_task(pub_broker.publish("hi", stream=queue)), + asyncio.create_task(br.publish("hi", stream=queue)), asyncio.create_task(event.wait()), ), timeout=3, diff --git a/tests/brokers/redis/test_router.py b/tests/brokers/redis/test_router.py index 67c01dc994..b67b56ad1f 100644 --- a/tests/brokers/redis/test_router.py +++ b/tests/brokers/redis/test_router.py @@ -2,6 +2,7 @@ import pytest +from faststream import Path from faststream.redis import RedisBroker, RedisPublisher, RedisRoute, RedisRouter from tests.brokers.base.router import RouterLocalTestcase, RouterTestcase @@ -18,6 +19,96 @@ class TestRouterLocal(RouterLocalTestcase): route_class = RedisRoute publisher_class = RedisPublisher + async def test_router_path( + self, + event, + mock, + router, + pub_broker, + ): + @router.subscriber("in.{name}.{id}") + async def h( + name: str = Path(), + id: int = Path("id"), + ): + event.set() + mock(name=name, id=id) + + pub_broker._is_apply_types = True + pub_broker.include_router(router) + + await pub_broker.start() + + await pub_broker.publish( + "", + "in.john.2", + rpc=True, + ) + + assert event.is_set() + mock.assert_called_once_with(name="john", id=2) + + async def test_router_path_with_prefix( + self, + event, + mock, + router, + pub_broker, + ): + router.prefix = "test." + + @router.subscriber("in.{name}.{id}") + async def h( + name: str = Path(), + id: int = Path("id"), + ): + event.set() + mock(name=name, id=id) + + pub_broker._is_apply_types = True + pub_broker.include_router(router) + + await pub_broker.start() + + await pub_broker.publish( + "", + "test.in.john.2", + rpc=True, + ) + + assert event.is_set() + mock.assert_called_once_with(name="john", id=2) + + async def test_router_delay_handler_path( + self, + event, + mock, + router, + pub_broker, + ): + async def h( + name: str = Path(), + id: int = Path("id"), + ): + event.set() + mock(name=name, id=id) + + r = type(router)(handlers=(self.route_class(h, channel="in.{name}.{id}"),)) + + pub_broker._is_apply_types = True + pub_broker.include_router(r) + + await pub_broker.start() + + await pub_broker.publish( + "", + "in.john.2", + rpc=True, + ) + + assert event.is_set() + mock.assert_called_once_with(name="john", id=2) + async def test_delayed_channel_handlers( self, event: asyncio.Event, diff --git a/tests/brokers/redis/test_rpc.py b/tests/brokers/redis/test_rpc.py index 4006ef7d0f..c149d20d01 100644 --- a/tests/brokers/redis/test_rpc.py +++ b/tests/brokers/redis/test_rpc.py @@ -6,14 +6,20 @@ @pytest.mark.redis() class TestRPC(BrokerRPCTestcase, ReplyAndConsumeForbidden): + def get_broker(self, apply_types: bool = False): + return RedisBroker(apply_types=apply_types) + @pytest.mark.asyncio() - async def test_list_rpc(self, queue: str, rpc_broker: RedisBroker): + async def test_list_rpc(self, queue: str): + rpc_broker = self.get_broker() + @rpc_broker.subscriber(list=queue) async def m(m): # pragma: no cover return "1" - async with rpc_broker: - await rpc_broker.start() - r = await rpc_broker.publish("hello", list=queue, rpc_timeout=3, rpc=True) + async with self.patch_broker(rpc_broker) as br: + await br.start() + + r = await br.publish("hello", list=queue, rpc_timeout=3, rpc=True) assert r == "1" diff --git a/tests/brokers/redis/test_test_client.py b/tests/brokers/redis/test_test_client.py index ba87d4e685..ae6340ad7a 100644 --- a/tests/brokers/redis/test_test_client.py +++ b/tests/brokers/redis/test_test_client.py @@ -3,19 +3,39 @@ import pytest from faststream import BaseMiddleware +from faststream.exceptions import SetupError from faststream.redis import ListSub, RedisBroker, StreamSub, TestRedisBroker from tests.brokers.base.testclient import BrokerTestclientTestcase @pytest.mark.asyncio() class TestTestclient(BrokerTestclientTestcase): + test_class = TestRedisBroker + + def get_broker(self, apply_types: bool = False) -> RedisBroker: + return RedisBroker(apply_types=apply_types) + + def patch_broker(self, broker: RedisBroker) -> TestRedisBroker: + return TestRedisBroker(broker) + + async def test_rpc_conflicts_reply(self, queue): + async with TestRedisBroker(RedisBroker()) as br: + with pytest.raises(SetupError): + await br.publish( + "", + queue, + rpc=True, + reply_to="response", + ) + @pytest.mark.redis() async def test_with_real_testclient( self, - broker: RedisBroker, queue: str, event: asyncio.Event, ): + broker = self.get_broker() + @broker.subscriber(queue) def subscriber(m): event.set() @@ -78,127 +98,131 @@ async def h2(): ... assert len(routes) == 2 - async def test_pub_sub_pattern( - self, - test_broker: RedisBroker, - ): - @test_broker.subscriber("test.{name}") + async def test_pub_sub_pattern(self): + broker = self.get_broker() + + @broker.subscriber("test.{name}") async def handler(msg): return msg - await test_broker.start() - - assert await test_broker.publish(1, "test.name.useless", rpc=True) == 1 - handler.mock.assert_called_once_with(1) + async with self.patch_broker(broker) as br: + assert await br.publish(1, "test.name.useless", rpc=True) == 1 + handler.mock.assert_called_once_with(1) async def test_list( self, - test_broker: RedisBroker, queue: str, ): - @test_broker.subscriber(list=queue) + broker = self.get_broker() + + @broker.subscriber(list=queue) async def handler(msg): return msg - await test_broker.start() - - assert await test_broker.publish(1, list=queue, rpc=True) == 1 - handler.mock.assert_called_once_with(1) + async with self.patch_broker(broker) as br: + assert await br.publish(1, list=queue, rpc=True) == 1 + handler.mock.assert_called_once_with(1) async def test_batch_pub_by_default_pub( self, - test_broker: RedisBroker, queue: str, ): - @test_broker.subscriber(list=ListSub(queue, batch=True)) - async def m(): + broker = self.get_broker() + + @broker.subscriber(list=ListSub(queue, batch=True)) + async def m(msg): pass - await test_broker.start() - await test_broker.publish("hello", list=queue) - m.mock.assert_called_once_with(["hello"]) + async with self.patch_broker(broker) as br: + await br.publish("hello", list=queue) + m.mock.assert_called_once_with(["hello"]) async def test_batch_pub_by_pub_batch( self, - test_broker: RedisBroker, queue: str, ): - @test_broker.subscriber(list=ListSub(queue, batch=True)) - async def m(): + broker = self.get_broker() + + @broker.subscriber(list=ListSub(queue, batch=True)) + async def m(msg): pass - await test_broker.start() - await test_broker.publish_batch("hello", list=queue) - m.mock.assert_called_once_with(["hello"]) + async with self.patch_broker(broker) as br: + await br.publish_batch("hello", list=queue) + m.mock.assert_called_once_with(["hello"]) async def test_batch_publisher_mock( self, - test_broker: RedisBroker, queue: str, ): + broker = self.get_broker() + batch_list = ListSub(queue + "1", batch=True) - publisher = test_broker.publisher(list=batch_list) + publisher = broker.publisher(list=batch_list) @publisher - @test_broker.subscriber(queue) - async def m(): + @broker.subscriber(queue) + async def m(msg): return 1, 2, 3 - await test_broker.start() - await test_broker.publish("hello", queue) - m.mock.assert_called_once_with("hello") - publisher.mock.assert_called_once_with([1, 2, 3]) + async with self.patch_broker(broker) as br: + await br.publish("hello", queue) + m.mock.assert_called_once_with("hello") + publisher.mock.assert_called_once_with([1, 2, 3]) async def test_stream( self, - test_broker: RedisBroker, queue: str, ): - @test_broker.subscriber(stream=queue) + broker = self.get_broker() + + @broker.subscriber(stream=queue) async def handler(msg): return msg - await test_broker.start() - - assert await test_broker.publish(1, stream=queue, rpc=True) == 1 - handler.mock.assert_called_once_with(1) + async with self.patch_broker(broker) as br: + assert await br.publish(1, stream=queue, rpc=True) == 1 + handler.mock.assert_called_once_with(1) async def test_stream_batch_pub_by_default_pub( self, - test_broker: RedisBroker, queue: str, ): - @test_broker.subscriber(stream=StreamSub(queue, batch=True)) - async def m(): + broker = self.get_broker() + + @broker.subscriber(stream=StreamSub(queue, batch=True)) + async def m(msg): pass - await test_broker.start() - await test_broker.publish("hello", stream=queue) - m.mock.assert_called_once_with(["hello"]) + async with self.patch_broker(broker) as br: + await br.publish("hello", stream=queue) + m.mock.assert_called_once_with(["hello"]) async def test_stream_publisher( self, - test_broker: RedisBroker, queue: str, ): + broker = self.get_broker() + batch_stream = StreamSub(queue + "1") - publisher = test_broker.publisher(stream=batch_stream) + publisher = broker.publisher(stream=batch_stream) @publisher - @test_broker.subscriber(queue) - async def m(): + @broker.subscriber(queue) + async def m(msg): return 1, 2, 3 - await test_broker.start() - await test_broker.publish("hello", queue) - m.mock.assert_called_once_with("hello") - publisher.mock.assert_called_once_with([1, 2, 3]) + async with self.patch_broker(broker) as br: + await br.publish("hello", queue) + m.mock.assert_called_once_with("hello") + publisher.mock.assert_called_once_with([1, 2, 3]) async def test_publish_to_none( self, - test_broker: RedisBroker, queue: str, ): - await test_broker.start() - with pytest.raises(ValueError): # noqa: PT011 - await test_broker.publish("hello") + broker = self.get_broker() + + async with self.patch_broker(broker) as br: + with pytest.raises(ValueError): # noqa: PT011 + await br.publish("hello") diff --git a/tests/cli/supervisors/test_base_reloader.py b/tests/cli/supervisors/test_base_reloader.py index c143d39c9f..2a1c2fd6ed 100644 --- a/tests/cli/supervisors/test_base_reloader.py +++ b/tests/cli/supervisors/test_base_reloader.py @@ -14,7 +14,7 @@ def should_restart(self) -> bool: return True -def empty(): +def empty(*args, **kwargs): pass diff --git a/tests/conftest.py b/tests/conftest.py index d15d9cb7a2..92778c660a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,6 @@ import asyncio from unittest.mock import AsyncMock, MagicMock +from uuid import uuid4 import pytest from typer.testing import CliRunner @@ -18,6 +19,11 @@ def pytest_collection_modifyitems(items): item.add_marker("all") +@pytest.fixture() +def queue(): + return str(uuid4()) + + @pytest.fixture() def event(): return asyncio.Event() diff --git a/tests/docs/confluent/__init__.py b/tests/docs/confluent/__init__.py index e69de29bb2..c4a1803708 100644 --- a/tests/docs/confluent/__init__.py +++ b/tests/docs/confluent/__init__.py @@ -0,0 +1,3 @@ +import pytest + +pytest.importorskip("confluent_kafka") diff --git a/tests/docs/kafka/__init__.py b/tests/docs/kafka/__init__.py index e69de29bb2..bd6bc708fc 100644 --- a/tests/docs/kafka/__init__.py +++ b/tests/docs/kafka/__init__.py @@ -0,0 +1,3 @@ +import pytest + +pytest.importorskip("aiokafka") diff --git a/tests/docs/nats/__init__.py b/tests/docs/nats/__init__.py index e69de29bb2..87ead90ee6 100644 --- a/tests/docs/nats/__init__.py +++ b/tests/docs/nats/__init__.py @@ -0,0 +1,3 @@ +import pytest + +pytest.importorskip("nats") diff --git a/tests/docs/rabbit/__init__.py b/tests/docs/rabbit/__init__.py index e69de29bb2..ebec43fcd5 100644 --- a/tests/docs/rabbit/__init__.py +++ b/tests/docs/rabbit/__init__.py @@ -0,0 +1,3 @@ +import pytest + +pytest.importorskip("aio_pika") diff --git a/tests/docs/redis/__init__.py b/tests/docs/redis/__init__.py index e69de29bb2..4752ef19b1 100644 --- a/tests/docs/redis/__init__.py +++ b/tests/docs/redis/__init__.py @@ -0,0 +1,3 @@ +import pytest + +pytest.importorskip("redis") diff --git a/tests/marks.py b/tests/marks.py index 4a41446988..80bb1cde5c 100644 --- a/tests/marks.py +++ b/tests/marks.py @@ -4,12 +4,22 @@ from faststream._compat import PYDANTIC_V2 -python39 = pytest.mark.skipif(sys.version_info < (3, 9), reason="requires python3.9+") +python39 = pytest.mark.skipif( + sys.version_info < (3, 9), + reason="requires python3.9+", +) python310 = pytest.mark.skipif( - sys.version_info < (3, 10), reason="requires python3.10+" + sys.version_info < (3, 10), + reason="requires python3.10+", ) -pydantic_v1 = pytest.mark.skipif(PYDANTIC_V2, reason="requires PydanticV2") +pydantic_v1 = pytest.mark.skipif( + PYDANTIC_V2, + reason="requires PydanticV2", +) -pydantic_v2 = pytest.mark.skipif(not PYDANTIC_V2, reason="requires PydanticV1") +pydantic_v2 = pytest.mark.skipif( + not PYDANTIC_V2, + reason="requires PydanticV1", +) diff --git a/tests/opentelemetry/__init__.py b/tests/opentelemetry/__init__.py new file mode 100644 index 0000000000..75763c2fee --- /dev/null +++ b/tests/opentelemetry/__init__.py @@ -0,0 +1,3 @@ +import pytest + +pytest.importorskip("opentelemetry") diff --git a/tests/opentelemetry/basic.py b/tests/opentelemetry/basic.py new file mode 100644 index 0000000000..794a09ee6d --- /dev/null +++ b/tests/opentelemetry/basic.py @@ -0,0 +1,357 @@ +import asyncio +from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, cast +from unittest.mock import Mock + +import pytest +from dirty_equals import IsFloat, IsUUID +from opentelemetry.sdk.metrics import MeterProvider +from opentelemetry.sdk.metrics._internal.point import Metric +from opentelemetry.sdk.metrics.export import InMemoryMetricReader +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace import Span, TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter +from opentelemetry.semconv.trace import SpanAttributes as SpanAttr +from opentelemetry.trace import SpanKind + +from faststream.broker.core.usecase import BrokerUsecase +from faststream.opentelemetry.consts import ( + ERROR_TYPE, + MESSAGING_DESTINATION_PUBLISH_NAME, +) +from faststream.opentelemetry.middleware import MessageAction as Action +from faststream.opentelemetry.middleware import TelemetryMiddleware + + +@pytest.mark.asyncio() +class LocalTelemetryTestcase: + messaging_system: str + include_messages_counters: bool + broker_class: Type[BrokerUsecase] + timeout: int = 3 + subscriber_kwargs: ClassVar[Dict[str, Any]] = {} + resource: Resource = Resource.create(attributes={"service.name": "faststream.test"}) + + telemetry_middleware_class: TelemetryMiddleware + + def patch_broker(self, broker: BrokerUsecase) -> BrokerUsecase: + return broker + + def destination_name(self, queue: str) -> str: + return queue + + @staticmethod + def get_spans(exporter: InMemorySpanExporter) -> List[Span]: + spans = cast(Tuple[Span, ...], exporter.get_finished_spans()) + return sorted(spans, key=lambda s: s.start_time) + + @staticmethod + def get_metrics( + reader: InMemoryMetricReader, + ) -> List[Metric]: + """Get sorted metrics. + + Return order: + - messaging.process.duration + - messaging.process.messages + - messaging.publish.duration + - messaging.publish.messages + """ + metrics = reader.get_metrics_data() + metrics = metrics.resource_metrics[0].scope_metrics[0].metrics + metrics = sorted(metrics, key=lambda m: m.name) + return cast(List[Metric], metrics) + + @pytest.fixture() + def tracer_provider(self) -> TracerProvider: + tracer_provider = TracerProvider(resource=self.resource) + return tracer_provider + + @pytest.fixture() + def trace_exporter(self, tracer_provider: TracerProvider) -> InMemorySpanExporter: + exporter = InMemorySpanExporter() + tracer_provider.add_span_processor(SimpleSpanProcessor(exporter)) + return exporter + + @pytest.fixture() + def metric_reader(self) -> InMemoryMetricReader: + return InMemoryMetricReader() + + @pytest.fixture() + def meter_provider(self, metric_reader: InMemoryMetricReader) -> MeterProvider: + return MeterProvider(metric_readers=(metric_reader,), resource=self.resource) + + def assert_span( + self, + span: Span, + action: str, + queue: str, + msg: str, + parent_span_id: Optional[str] = None, + ) -> None: + attrs = span.attributes + assert attrs[SpanAttr.MESSAGING_SYSTEM] == self.messaging_system, attrs[ + SpanAttr.MESSAGING_SYSTEM + ] + assert attrs[SpanAttr.MESSAGING_MESSAGE_CONVERSATION_ID] == IsUUID, attrs[ + SpanAttr.MESSAGING_MESSAGE_CONVERSATION_ID + ] + assert span.name == f"{self.destination_name(queue)} {action}", span.name + assert span.kind in (SpanKind.CONSUMER, SpanKind.PRODUCER), span.kind + + if span.kind == SpanKind.PRODUCER and action in (Action.CREATE, Action.PUBLISH): + assert attrs[SpanAttr.MESSAGING_DESTINATION_NAME] == queue, attrs[ + SpanAttr.MESSAGING_DESTINATION_NAME + ] + + if span.kind == SpanKind.CONSUMER and action in (Action.CREATE, Action.PROCESS): + assert attrs[MESSAGING_DESTINATION_PUBLISH_NAME] == queue, attrs[ + MESSAGING_DESTINATION_PUBLISH_NAME + ] + assert attrs[SpanAttr.MESSAGING_MESSAGE_ID] == IsUUID, attrs[ + SpanAttr.MESSAGING_MESSAGE_ID + ] + + if action == Action.PROCESS: + assert attrs[SpanAttr.MESSAGING_MESSAGE_PAYLOAD_SIZE_BYTES] == len( + msg + ), attrs[SpanAttr.MESSAGING_MESSAGE_PAYLOAD_SIZE_BYTES] + assert attrs[SpanAttr.MESSAGING_OPERATION] == action, attrs[ + SpanAttr.MESSAGING_OPERATION + ] + + if action == Action.PUBLISH: + assert attrs[SpanAttr.MESSAGING_OPERATION] == action, attrs[ + SpanAttr.MESSAGING_OPERATION + ] + + if parent_span_id: + assert span.parent.span_id == parent_span_id, span.parent.span_id + + def assert_metrics( + self, + metrics: List[Metric], + count: int = 1, + error_type: Optional[str] = None, + ) -> None: + if self.include_messages_counters: + assert len(metrics) == 4 + proc_dur, proc_msg, pub_dur, pub_msg = metrics + + assert proc_msg.data.data_points[0].value == count + assert pub_msg.data.data_points[0].value == count + + else: + assert len(metrics) == 2 + proc_dur, pub_dur = metrics + + if error_type: + assert proc_dur.data.data_points[0].attributes[ERROR_TYPE] == error_type + + assert proc_dur.data.data_points[0].count == 1 + assert proc_dur.data.data_points[0].sum == IsFloat + + assert pub_dur.data.data_points[0].count == 1 + assert pub_dur.data.data_points[0].sum == IsFloat + + async def test_subscriber_create_publish_process_span( + self, + event: asyncio.Event, + queue: str, + mock: Mock, + tracer_provider: TracerProvider, + trace_exporter: InMemorySpanExporter, + ): + mid = self.telemetry_middleware_class(tracer_provider=tracer_provider) + broker = self.broker_class(middlewares=(mid,)) + + @broker.subscriber(queue, **self.subscriber_kwargs) + async def handler(m): + mock(m) + event.set() + + broker = self.patch_broker(broker) + msg = "start" + + async with broker: + await broker.start() + tasks = ( + asyncio.create_task(broker.publish(msg, queue)), + asyncio.create_task(event.wait()), + ) + await asyncio.wait(tasks, timeout=self.timeout) + + create, publish, process = self.get_spans(trace_exporter) + parent_span_id = create.context.span_id + + self.assert_span(create, Action.CREATE, queue, msg) + self.assert_span(publish, Action.PUBLISH, queue, msg, parent_span_id) + self.assert_span(process, Action.PROCESS, queue, msg, parent_span_id) + + assert event.is_set() + mock.assert_called_once_with(msg) + + async def test_chain_subscriber_publisher( + self, + event: asyncio.Event, + queue: str, + mock: Mock, + tracer_provider: TracerProvider, + trace_exporter: InMemorySpanExporter, + ): + mid = self.telemetry_middleware_class(tracer_provider=tracer_provider) + broker = self.broker_class(middlewares=(mid,)) + + first_queue = queue + second_queue = queue + "2" + + @broker.subscriber(first_queue, **self.subscriber_kwargs) + @broker.publisher(second_queue) + async def handler1(m): + return m + + @broker.subscriber(second_queue, **self.subscriber_kwargs) + async def handler2(m): + mock(m) + event.set() + + broker = self.patch_broker(broker) + msg = "start" + + async with broker: + await broker.start() + tasks = ( + asyncio.create_task(broker.publish(msg, queue)), + asyncio.create_task(event.wait()), + ) + await asyncio.wait(tasks, timeout=self.timeout) + + spans = self.get_spans(trace_exporter) + create, pub1, proc1, pub2, proc2 = spans + parent_span_id = create.context.span_id + + self.assert_span(create, Action.CREATE, first_queue, msg) + self.assert_span(pub1, Action.PUBLISH, first_queue, msg, parent_span_id) + self.assert_span(proc1, Action.PROCESS, first_queue, msg, parent_span_id) + self.assert_span(pub2, Action.PUBLISH, second_queue, msg, proc1.context.span_id) + self.assert_span(proc2, Action.PROCESS, second_queue, msg, parent_span_id) + + assert ( + create.start_time + < pub1.start_time + < proc1.start_time + < pub2.start_time + < proc2.start_time + ) + + assert event.is_set() + mock.assert_called_once_with(msg) + + async def test_no_trace_context_create_process_span( + self, + event: asyncio.Event, + queue: str, + mock: Mock, + tracer_provider: TracerProvider, + trace_exporter: InMemorySpanExporter, + ): + mid = self.telemetry_middleware_class(tracer_provider=tracer_provider) + broker = self.broker_class(middlewares=(mid,)) + + @broker.subscriber(queue, **self.subscriber_kwargs) + async def handler(m): + mock(m) + event.set() + + broker = self.patch_broker(broker) + msg = "start" + + async with broker: + await broker.start() + broker._middlewares = () + tasks = ( + asyncio.create_task(broker.publish(msg, queue)), + asyncio.create_task(event.wait()), + ) + await asyncio.wait(tasks, timeout=self.timeout) + + create, process = self.get_spans(trace_exporter) + parent_span_id = create.context.span_id + + self.assert_span(create, Action.CREATE, queue, msg) + self.assert_span(process, Action.PROCESS, queue, msg, parent_span_id) + + assert event.is_set() + mock.assert_called_once_with(msg) + + async def test_metrics( + self, + event: asyncio.Event, + queue: str, + mock: Mock, + meter_provider: MeterProvider, + metric_reader: InMemoryMetricReader, + ): + mid = self.telemetry_middleware_class(meter_provider=meter_provider) + broker = self.broker_class(middlewares=(mid,)) + + @broker.subscriber(queue, **self.subscriber_kwargs) + async def handler(m): + mock(m) + event.set() + + broker = self.patch_broker(broker) + msg = "start" + + async with broker: + await broker.start() + tasks = ( + asyncio.create_task(broker.publish(msg, queue)), + asyncio.create_task(event.wait()), + ) + await asyncio.wait(tasks, timeout=self.timeout) + + metrics = self.get_metrics(metric_reader) + + self.assert_metrics(metrics) + + assert event.is_set() + mock.assert_called_once_with(msg) + + async def test_error_metrics( + self, + event: asyncio.Event, + queue: str, + mock: Mock, + meter_provider: MeterProvider, + metric_reader: InMemoryMetricReader, + ): + mid = self.telemetry_middleware_class(meter_provider=meter_provider) + broker = self.broker_class(middlewares=(mid,)) + expected_value_type = "ValueError" + + @broker.subscriber(queue, **self.subscriber_kwargs) + async def handler(m): + try: + raise ValueError + finally: + mock(m) + event.set() + + broker = self.patch_broker(broker) + msg = "start" + + async with broker: + await broker.start() + tasks = ( + asyncio.create_task(broker.publish(msg, queue)), + asyncio.create_task(event.wait()), + ) + await asyncio.wait(tasks, timeout=self.timeout) + + metrics = self.get_metrics(metric_reader) + + self.assert_metrics(metrics, error_type=expected_value_type) + + assert event.is_set() + mock.assert_called_once_with(msg) diff --git a/tests/opentelemetry/confluent/__init__.py b/tests/opentelemetry/confluent/__init__.py new file mode 100644 index 0000000000..c4a1803708 --- /dev/null +++ b/tests/opentelemetry/confluent/__init__.py @@ -0,0 +1,3 @@ +import pytest + +pytest.importorskip("confluent_kafka") diff --git a/tests/opentelemetry/confluent/test_confluent.py b/tests/opentelemetry/confluent/test_confluent.py new file mode 100644 index 0000000000..3877d488ba --- /dev/null +++ b/tests/opentelemetry/confluent/test_confluent.py @@ -0,0 +1,130 @@ +import asyncio +from typing import Any, ClassVar, Dict, Optional +from unittest.mock import Mock + +import pytest +from dirty_equals import IsStr, IsUUID +from opentelemetry.sdk.metrics import MeterProvider +from opentelemetry.sdk.metrics.export import InMemoryMetricReader +from opentelemetry.sdk.trace import Span, TracerProvider +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter +from opentelemetry.semconv.trace import SpanAttributes as SpanAttr +from opentelemetry.trace import SpanKind + +from faststream.confluent import KafkaBroker +from faststream.confluent.opentelemetry import KafkaTelemetryMiddleware +from faststream.opentelemetry.consts import MESSAGING_DESTINATION_PUBLISH_NAME +from faststream.opentelemetry.middleware import MessageAction as Action +from tests.brokers.confluent.test_consume import TestConsume +from tests.brokers.confluent.test_publish import TestPublish + +from ..basic import LocalTelemetryTestcase + + +@pytest.mark.confluent() +class TestTelemetry(LocalTelemetryTestcase): + messaging_system = "kafka" + include_messages_counters = True + timeout: int = 10 + subscriber_kwargs: ClassVar[Dict[str, Any]] = {"auto_offset_reset": "earliest"} + broker_class = KafkaBroker + telemetry_middleware_class = KafkaTelemetryMiddleware + + def assert_span( + self, + span: Span, + action: str, + queue: str, + msg: str, + parent_span_id: Optional[str] = None, + ) -> None: + attrs = span.attributes + assert attrs[SpanAttr.MESSAGING_SYSTEM] == self.messaging_system + assert attrs[SpanAttr.MESSAGING_MESSAGE_CONVERSATION_ID] == IsUUID + assert span.name == f"{self.destination_name(queue)} {action}" + assert span.kind in (SpanKind.CONSUMER, SpanKind.PRODUCER) + + if span.kind == SpanKind.PRODUCER and action in (Action.CREATE, Action.PUBLISH): + assert attrs[SpanAttr.MESSAGING_DESTINATION_NAME] == queue + + if span.kind == SpanKind.CONSUMER and action in (Action.CREATE, Action.PROCESS): + assert attrs[MESSAGING_DESTINATION_PUBLISH_NAME] == queue + assert attrs[SpanAttr.MESSAGING_MESSAGE_ID] == IsStr(regex=r"0-.+") + assert attrs[SpanAttr.MESSAGING_KAFKA_DESTINATION_PARTITION] == 0 + assert attrs[SpanAttr.MESSAGING_KAFKA_MESSAGE_OFFSET] == 0 + + if action == Action.PROCESS: + assert attrs[SpanAttr.MESSAGING_MESSAGE_PAYLOAD_SIZE_BYTES] == len(msg) + assert attrs[SpanAttr.MESSAGING_OPERATION] == action + + if action == Action.PUBLISH: + assert attrs[SpanAttr.MESSAGING_OPERATION] == action + + if parent_span_id: + assert span.parent.span_id == parent_span_id + + async def test_batch( + self, + event: asyncio.Event, + queue: str, + mock: Mock, + meter_provider: MeterProvider, + metric_reader: InMemoryMetricReader, + tracer_provider: TracerProvider, + trace_exporter: InMemorySpanExporter, + ): + mid = self.telemetry_middleware_class( + meter_provider=meter_provider, tracer_provider=tracer_provider + ) + broker = self.broker_class(middlewares=(mid,)) + expected_msg_count = 3 + + @broker.subscriber(queue, batch=True, **self.subscriber_kwargs) + async def handler(m): + mock(m) + event.set() + + broker = self.patch_broker(broker) + + async with broker: + await broker.start() + tasks = ( + asyncio.create_task(broker.publish_batch(1, "hi", 3, topic=queue)), + asyncio.create_task(event.wait()), + ) + await asyncio.wait(tasks, timeout=self.timeout) + + metrics = self.get_metrics(metric_reader) + spans = self.get_spans(trace_exporter) + _, publish, process = spans + + assert ( + publish.attributes[SpanAttr.MESSAGING_BATCH_MESSAGE_COUNT] + == expected_msg_count + ) + assert ( + process.attributes[SpanAttr.MESSAGING_BATCH_MESSAGE_COUNT] + == expected_msg_count + ) + self.assert_metrics(metrics, count=expected_msg_count) + + assert event.is_set() + mock.assert_called_once_with([1, "hi", 3]) + + +@pytest.mark.confluent() +class TestPublishWithTelemetry(TestPublish): + def get_broker(self, apply_types: bool = False): + return KafkaBroker( + middlewares=(KafkaTelemetryMiddleware(),), + apply_types=apply_types, + ) + + +@pytest.mark.confluent() +class TestConsumeWithTelemetry(TestConsume): + def get_broker(self, apply_types: bool = False): + return KafkaBroker( + middlewares=(KafkaTelemetryMiddleware(),), + apply_types=apply_types, + ) diff --git a/tests/opentelemetry/kafka/__init__.py b/tests/opentelemetry/kafka/__init__.py new file mode 100644 index 0000000000..bd6bc708fc --- /dev/null +++ b/tests/opentelemetry/kafka/__init__.py @@ -0,0 +1,3 @@ +import pytest + +pytest.importorskip("aiokafka") diff --git a/tests/opentelemetry/kafka/test_kafka.py b/tests/opentelemetry/kafka/test_kafka.py new file mode 100644 index 0000000000..2142825098 --- /dev/null +++ b/tests/opentelemetry/kafka/test_kafka.py @@ -0,0 +1,128 @@ +import asyncio +from typing import Optional +from unittest.mock import Mock + +import pytest +from dirty_equals import IsStr, IsUUID +from opentelemetry.sdk.metrics import MeterProvider +from opentelemetry.sdk.metrics.export import InMemoryMetricReader +from opentelemetry.sdk.trace import Span, TracerProvider +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter +from opentelemetry.semconv.trace import SpanAttributes as SpanAttr +from opentelemetry.trace import SpanKind + +from faststream.kafka import KafkaBroker +from faststream.kafka.opentelemetry import KafkaTelemetryMiddleware +from faststream.opentelemetry.consts import MESSAGING_DESTINATION_PUBLISH_NAME +from faststream.opentelemetry.middleware import MessageAction as Action +from tests.brokers.kafka.test_consume import TestConsume +from tests.brokers.kafka.test_publish import TestPublish + +from ..basic import LocalTelemetryTestcase + + +@pytest.mark.kafka() +class TestTelemetry(LocalTelemetryTestcase): + messaging_system = "kafka" + include_messages_counters = True + broker_class = KafkaBroker + telemetry_middleware_class = KafkaTelemetryMiddleware + + def assert_span( + self, + span: Span, + action: str, + queue: str, + msg: str, + parent_span_id: Optional[str] = None, + ) -> None: + attrs = span.attributes + assert attrs[SpanAttr.MESSAGING_SYSTEM] == self.messaging_system + assert attrs[SpanAttr.MESSAGING_MESSAGE_CONVERSATION_ID] == IsUUID + assert span.name == f"{self.destination_name(queue)} {action}" + assert span.kind in (SpanKind.CONSUMER, SpanKind.PRODUCER) + + if span.kind == SpanKind.PRODUCER and action in (Action.CREATE, Action.PUBLISH): + assert attrs[SpanAttr.MESSAGING_DESTINATION_NAME] == queue + + if span.kind == SpanKind.CONSUMER and action in (Action.CREATE, Action.PROCESS): + assert attrs[MESSAGING_DESTINATION_PUBLISH_NAME] == queue + assert attrs[SpanAttr.MESSAGING_MESSAGE_ID] == IsStr(regex=r"0-.+") + assert attrs[SpanAttr.MESSAGING_KAFKA_DESTINATION_PARTITION] == 0 + assert attrs[SpanAttr.MESSAGING_KAFKA_MESSAGE_OFFSET] == 0 + + if action == Action.PROCESS: + assert attrs[SpanAttr.MESSAGING_MESSAGE_PAYLOAD_SIZE_BYTES] == len(msg) + assert attrs[SpanAttr.MESSAGING_OPERATION] == action + + if action == Action.PUBLISH: + assert attrs[SpanAttr.MESSAGING_OPERATION] == action + + if parent_span_id: + assert span.parent.span_id == parent_span_id + + async def test_batch( + self, + event: asyncio.Event, + queue: str, + mock: Mock, + meter_provider: MeterProvider, + metric_reader: InMemoryMetricReader, + tracer_provider: TracerProvider, + trace_exporter: InMemorySpanExporter, + ): + mid = self.telemetry_middleware_class( + meter_provider=meter_provider, tracer_provider=tracer_provider + ) + broker = self.broker_class(middlewares=(mid,)) + expected_msg_count = 3 + + @broker.subscriber(queue, batch=True, **self.subscriber_kwargs) + async def handler(m): + mock(m) + event.set() + + broker = self.patch_broker(broker) + + async with broker: + await broker.start() + tasks = ( + asyncio.create_task(broker.publish_batch(1, "hi", 3, topic=queue)), + asyncio.create_task(event.wait()), + ) + await asyncio.wait(tasks, timeout=self.timeout) + + metrics = self.get_metrics(metric_reader) + spans = self.get_spans(trace_exporter) + _, publish, process = spans + + assert ( + publish.attributes[SpanAttr.MESSAGING_BATCH_MESSAGE_COUNT] + == expected_msg_count + ) + assert ( + process.attributes[SpanAttr.MESSAGING_BATCH_MESSAGE_COUNT] + == expected_msg_count + ) + self.assert_metrics(metrics, count=expected_msg_count) + + assert event.is_set() + mock.assert_called_once_with([1, "hi", 3]) + + +@pytest.mark.kafka() +class TestPublishWithTelemetry(TestPublish): + def get_broker(self, apply_types: bool = False): + return KafkaBroker( + middlewares=(KafkaTelemetryMiddleware(),), + apply_types=apply_types, + ) + + +@pytest.mark.kafka() +class TestConsumeWithTelemetry(TestConsume): + def get_broker(self, apply_types: bool = False): + return KafkaBroker( + middlewares=(KafkaTelemetryMiddleware(),), + apply_types=apply_types, + ) diff --git a/tests/opentelemetry/nats/__init__.py b/tests/opentelemetry/nats/__init__.py new file mode 100644 index 0000000000..87ead90ee6 --- /dev/null +++ b/tests/opentelemetry/nats/__init__.py @@ -0,0 +1,3 @@ +import pytest + +pytest.importorskip("nats") diff --git a/tests/opentelemetry/nats/test_nats.py b/tests/opentelemetry/nats/test_nats.py new file mode 100644 index 0000000000..b886e46d8f --- /dev/null +++ b/tests/opentelemetry/nats/test_nats.py @@ -0,0 +1,103 @@ +import asyncio +from unittest.mock import Mock + +import pytest +from opentelemetry.sdk.metrics import MeterProvider +from opentelemetry.sdk.metrics.export import InMemoryMetricReader +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter +from opentelemetry.semconv.trace import SpanAttributes as SpanAttr + +from faststream.nats import JStream, NatsBroker, PullSub +from faststream.nats.opentelemetry import NatsTelemetryMiddleware +from tests.brokers.nats.test_consume import TestConsume +from tests.brokers.nats.test_publish import TestPublish + +from ..basic import LocalTelemetryTestcase + + +@pytest.fixture() +def stream(queue): + return JStream(queue) + + +@pytest.mark.nats() +class TestTelemetry(LocalTelemetryTestcase): + messaging_system = "nats" + include_messages_counters = True + broker_class = NatsBroker + telemetry_middleware_class = NatsTelemetryMiddleware + + async def test_batch( + self, + event: asyncio.Event, + queue: str, + mock: Mock, + stream: JStream, + meter_provider: MeterProvider, + metric_reader: InMemoryMetricReader, + tracer_provider: TracerProvider, + trace_exporter: InMemorySpanExporter, + ): + mid = self.telemetry_middleware_class( + meter_provider=meter_provider, tracer_provider=tracer_provider + ) + broker = self.broker_class(middlewares=(mid,)) + expected_msg_count = 3 + + @broker.subscriber( + queue, + stream=stream, + pull_sub=PullSub(3, batch=True), + **self.subscriber_kwargs, + ) + async def handler(m): + mock(m) + event.set() + + broker = self.patch_broker(broker) + + async with broker: + await broker.start() + tasks = ( + asyncio.create_task(broker.publish(1, queue)), + asyncio.create_task(broker.publish("hi", queue)), + asyncio.create_task(broker.publish(3, queue)), + asyncio.create_task(event.wait()), + ) + await asyncio.wait(tasks, timeout=self.timeout) + + metrics = self.get_metrics(metric_reader) + proc_dur, proc_msg, pub_dur, pub_msg = metrics + spans = self.get_spans(trace_exporter) + process = spans[-1] + + assert ( + process.attributes[SpanAttr.MESSAGING_BATCH_MESSAGE_COUNT] + == expected_msg_count + ) + assert proc_msg.data.data_points[0].value == expected_msg_count + assert pub_msg.data.data_points[0].value == expected_msg_count + assert proc_dur.data.data_points[0].count == 1 + assert pub_dur.data.data_points[0].count == expected_msg_count + + assert event.is_set() + mock.assert_called_once_with([1, "hi", 3]) + + +@pytest.mark.nats() +class TestPublishWithTelemetry(TestPublish): + def get_broker(self, apply_types: bool = False): + return NatsBroker( + middlewares=(NatsTelemetryMiddleware(),), + apply_types=apply_types, + ) + + +@pytest.mark.nats() +class TestConsumeWithTelemetry(TestConsume): + def get_broker(self, apply_types: bool = False): + return NatsBroker( + middlewares=(NatsTelemetryMiddleware(),), + apply_types=apply_types, + ) diff --git a/tests/opentelemetry/rabbit/__init__.py b/tests/opentelemetry/rabbit/__init__.py new file mode 100644 index 0000000000..ebec43fcd5 --- /dev/null +++ b/tests/opentelemetry/rabbit/__init__.py @@ -0,0 +1,3 @@ +import pytest + +pytest.importorskip("aio_pika") diff --git a/tests/opentelemetry/rabbit/test_rabbit.py b/tests/opentelemetry/rabbit/test_rabbit.py new file mode 100644 index 0000000000..120ac3cd1c --- /dev/null +++ b/tests/opentelemetry/rabbit/test_rabbit.py @@ -0,0 +1,83 @@ +from typing import Optional + +import pytest +from dirty_equals import IsInt, IsUUID +from opentelemetry.sdk.trace import Span +from opentelemetry.semconv.trace import SpanAttributes as SpanAttr +from opentelemetry.trace import SpanKind + +from faststream.opentelemetry.consts import MESSAGING_DESTINATION_PUBLISH_NAME +from faststream.opentelemetry.middleware import MessageAction as Action +from faststream.rabbit import RabbitBroker, RabbitExchange +from faststream.rabbit.opentelemetry import RabbitTelemetryMiddleware +from tests.brokers.rabbit.test_consume import TestConsume +from tests.brokers.rabbit.test_publish import TestPublish + +from ..basic import LocalTelemetryTestcase + + +@pytest.fixture() +def exchange(queue): + return RabbitExchange(name=queue) + + +@pytest.mark.rabbit() +class TestTelemetry(LocalTelemetryTestcase): + messaging_system = "rabbitmq" + include_messages_counters = False + broker_class = RabbitBroker + telemetry_middleware_class = RabbitTelemetryMiddleware + + def destination_name(self, queue: str) -> str: + return f"default.{queue}" + + def assert_span( + self, + span: Span, + action: str, + queue: str, + msg: str, + parent_span_id: Optional[str] = None, + ) -> None: + attrs = span.attributes + assert attrs[SpanAttr.MESSAGING_SYSTEM] == self.messaging_system + assert attrs[SpanAttr.MESSAGING_MESSAGE_CONVERSATION_ID] == IsUUID + assert attrs[SpanAttr.MESSAGING_RABBITMQ_DESTINATION_ROUTING_KEY] == queue + assert span.name == f"{self.destination_name(queue)} {action}" + assert span.kind in (SpanKind.CONSUMER, SpanKind.PRODUCER) + + if span.kind == SpanKind.PRODUCER and action in (Action.CREATE, Action.PUBLISH): + assert attrs[SpanAttr.MESSAGING_DESTINATION_NAME] == "" + + if span.kind == SpanKind.CONSUMER and action in (Action.CREATE, Action.PROCESS): + assert attrs[MESSAGING_DESTINATION_PUBLISH_NAME] == "" + assert attrs["messaging.rabbitmq.message.delivery_tag"] == IsInt + assert attrs[SpanAttr.MESSAGING_MESSAGE_ID] == IsUUID + + if action == Action.PROCESS: + assert attrs[SpanAttr.MESSAGING_MESSAGE_PAYLOAD_SIZE_BYTES] == len(msg) + assert attrs[SpanAttr.MESSAGING_OPERATION] == action + + if action == Action.PUBLISH: + assert attrs[SpanAttr.MESSAGING_OPERATION] == action + + if parent_span_id: + assert span.parent.span_id == parent_span_id + + +@pytest.mark.rabbit() +class TestPublishWithTelemetry(TestPublish): + def get_broker(self, apply_types: bool = False): + return RabbitBroker( + middlewares=(RabbitTelemetryMiddleware(),), + apply_types=apply_types, + ) + + +@pytest.mark.rabbit() +class TestConsumeWithTelemetry(TestConsume): + def get_broker(self, apply_types: bool = False): + return RabbitBroker( + middlewares=(RabbitTelemetryMiddleware(),), + apply_types=apply_types, + ) diff --git a/tests/opentelemetry/redis/__init__.py b/tests/opentelemetry/redis/__init__.py new file mode 100644 index 0000000000..4752ef19b1 --- /dev/null +++ b/tests/opentelemetry/redis/__init__.py @@ -0,0 +1,3 @@ +import pytest + +pytest.importorskip("redis") diff --git a/tests/opentelemetry/redis/test_redis.py b/tests/opentelemetry/redis/test_redis.py new file mode 100644 index 0000000000..71e079cbac --- /dev/null +++ b/tests/opentelemetry/redis/test_redis.py @@ -0,0 +1,112 @@ +import asyncio +from unittest.mock import Mock + +import pytest +from opentelemetry.sdk.metrics import MeterProvider +from opentelemetry.sdk.metrics.export import InMemoryMetricReader +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter +from opentelemetry.semconv.trace import SpanAttributes as SpanAttr + +from faststream.redis import ListSub, RedisBroker +from faststream.redis.opentelemetry import RedisTelemetryMiddleware +from tests.brokers.redis.test_consume import ( + TestConsume, + TestConsumeList, + TestConsumeStream, +) +from tests.brokers.redis.test_publish import TestPublish + +from ..basic import LocalTelemetryTestcase + + +@pytest.mark.redis() +class TestTelemetry(LocalTelemetryTestcase): + messaging_system = "redis" + include_messages_counters = True + broker_class = RedisBroker + telemetry_middleware_class = RedisTelemetryMiddleware + + async def test_batch( + self, + event: asyncio.Event, + queue: str, + mock: Mock, + meter_provider: MeterProvider, + metric_reader: InMemoryMetricReader, + tracer_provider: TracerProvider, + trace_exporter: InMemorySpanExporter, + ): + mid = self.telemetry_middleware_class( + meter_provider=meter_provider, tracer_provider=tracer_provider + ) + broker = self.broker_class(middlewares=(mid,)) + expected_msg_count = 3 + + @broker.subscriber(list=ListSub(queue, batch=True), **self.subscriber_kwargs) + async def handler(m): + mock(m) + event.set() + + broker = self.patch_broker(broker) + + async with broker: + await broker.start() + tasks = ( + asyncio.create_task(broker.publish_batch(1, "hi", 3, list=queue)), + asyncio.create_task(event.wait()), + ) + await asyncio.wait(tasks, timeout=self.timeout) + + metrics = self.get_metrics(metric_reader) + spans = self.get_spans(trace_exporter) + _, publish, process = spans + + assert ( + publish.attributes[SpanAttr.MESSAGING_BATCH_MESSAGE_COUNT] + == expected_msg_count + ) + assert ( + process.attributes[SpanAttr.MESSAGING_BATCH_MESSAGE_COUNT] + == expected_msg_count + ) + self.assert_metrics(metrics, count=expected_msg_count) + + assert event.is_set() + mock.assert_called_once_with([1, "hi", 3]) + + +@pytest.mark.redis() +class TestPublishWithTelemetry(TestPublish): + def get_broker(self, apply_types: bool = False): + return RedisBroker( + middlewares=(RedisTelemetryMiddleware(),), + apply_types=apply_types, + ) + + +@pytest.mark.redis() +class TestConsumeWithTelemetry(TestConsume): + def get_broker(self, apply_types: bool = False): + return RedisBroker( + middlewares=(RedisTelemetryMiddleware(),), + apply_types=apply_types, + ) + + +@pytest.mark.redis() +class TestConsumeListWithTelemetry(TestConsumeList): + def get_broker(self, apply_types: bool = False): + return RedisBroker( + middlewares=(RedisTelemetryMiddleware(),), + apply_types=apply_types, + ) + + +@pytest.mark.redis() +class TestConsumeStreamWithTelemetry(TestConsumeStream): + def get_broker(self, apply_types: bool = False): + return RedisBroker( + middlewares=(RedisTelemetryMiddleware(),), + apply_types=apply_types, + )