diff --git a/dj_cqrs/transport/base.py b/dj_cqrs/transport/base.py index e66abc0..138175b 100644 --- a/dj_cqrs/transport/base.py +++ b/dj_cqrs/transport/base.py @@ -30,3 +30,8 @@ def produce(payload): def consume(*args, **kwargs): """Receive data from master model.""" raise NotImplementedError + + @staticmethod + def clean_connection(*args, **kwargs): + """Clean transport connection. Here you can close all connections that you have""" + raise NotImplementedError diff --git a/dj_cqrs/transport/kombu.py b/dj_cqrs/transport/kombu.py index c5c474c..8b4f6db 100644 --- a/dj_cqrs/transport/kombu.py +++ b/dj_cqrs/transport/kombu.py @@ -69,6 +69,11 @@ def get_consumers(self, Consumer, channel): class KombuTransport(LoggingMixin, BaseTransport): CONSUMER_RETRY_TIMEOUT = 5 + @classmethod + def clean_connection(cls): + """Nothing to do here""" + pass + @classmethod def consume(cls): queue_name, prefetch_count = cls._get_consumer_settings() diff --git a/dj_cqrs/transport/rabbit_mq.py b/dj_cqrs/transport/rabbit_mq.py index 7d19c68..9208330 100644 --- a/dj_cqrs/transport/rabbit_mq.py +++ b/dj_cqrs/transport/rabbit_mq.py @@ -6,7 +6,6 @@ from socket import gaierror from urllib.parse import unquote, urlparse - import ujson from django.conf import settings from pika import exceptions, BasicProperties, BlockingConnection, ConnectionParameters, credentials @@ -24,6 +23,16 @@ class RabbitMQTransport(LoggingMixin, BaseTransport): CONSUMER_RETRY_TIMEOUT = 5 + _producer_connection = None + _producer_channel = None + + @classmethod + def clean_connection(cls): + if cls._producer_connection and not cls._producer_connection.is_closed: + cls._producer_connection.close() + cls._producer_connection = None + cls._producer_channel = None + @classmethod def consume(cls): consumer_rabbit_settings = cls._get_consumer_settings() @@ -43,18 +52,19 @@ def consume(cls): logger.error('AMQP connection error. Reconnecting...') time.sleep(cls.CONSUMER_RETRY_TIMEOUT) finally: - if connection: + if connection and not connection.is_closed: connection.close() @classmethod def produce(cls, payload): + # TODO: try to produce and reconnect several times, now leave as before + # if cannot publish message - drop it and try to reconnect on next event rmq_settings = cls._get_common_settings() exchange = rmq_settings[-1] - connection = None try: # Decided not to create context-manager to stay within the class - connection, channel = cls._get_producer_rmq_objects(*rmq_settings) + _, channel = cls._get_producer_rmq_objects(*rmq_settings) cls._produce_message(channel, exchange, payload) cls.log_produced(payload) @@ -62,9 +72,9 @@ def produce(cls, payload): logger.error("CQRS couldn't be published: pk = {} ({}).".format( payload.pk, payload.cqrs_id, )) - finally: - if connection: - connection.close() + + # in case of any error - close connection and try to reconnect + cls.clean_connection() @classmethod def _consume_message(cls, ch, method, properties, body): @@ -114,7 +124,7 @@ def _produce_message(cls, channel, exchange, payload): properties=BasicProperties( content_type='text/plain', delivery_mode=2, # make message persistent - expiration='60000', # milliseconds + expiration=settings.CQRS.get('MESSAGE_TTL', '60000'), # milliseconds ) ) @@ -159,18 +169,22 @@ def _get_consumer_rmq_objects(cls, host, port, creds, exchange, queue_name, pref @classmethod def _get_producer_rmq_objects(cls, host, port, creds, exchange): - connection = BlockingConnection( - ConnectionParameters( - host=host, - port=port, - credentials=creds, - blocked_connection_timeout=10, - ), - ) - channel = connection.channel() - cls._declare_exchange(channel, exchange) + if cls._producer_connection is None: + connection = BlockingConnection( + ConnectionParameters( + host=host, + port=port, + credentials=creds, + blocked_connection_timeout=10, + ), + ) + channel = connection.channel() + cls._declare_exchange(channel, exchange) - return connection, channel + cls._producer_connection = connection + cls._producer_channel = channel + + return cls._producer_connection, cls._producer_channel @staticmethod def _declare_exchange(channel, exchange): diff --git a/integration_tests/tests/conftest.py b/integration_tests/tests/conftest.py index 72d0e69..063c76b 100644 --- a/integration_tests/tests/conftest.py +++ b/integration_tests/tests/conftest.py @@ -5,6 +5,8 @@ from integration_tests.tests.utils import REPLICA_TABLES +from dj_cqrs.transport import current_transport + @pytest.fixture def replica_cursor(): @@ -21,3 +23,10 @@ def replica_cursor(): cursor.close() connection.close() + + +@pytest.fixture +def clean_rabbit_transport_connection(): + current_transport.clean_connection() + + yield diff --git a/integration_tests/tests/test_asynchronous_consuming.py b/integration_tests/tests/test_asynchronous_consuming.py index 44d9abe..b4769e3 100644 --- a/integration_tests/tests/test_asynchronous_consuming.py +++ b/integration_tests/tests/test_asynchronous_consuming.py @@ -10,7 +10,8 @@ @pytest.mark.django_db(transaction=True) -def test_both_consumers_consume(replica_cursor): +def test_both_consumers_consume(settings, replica_cursor, clean_rabbit_transport_connection): + settings.CQRS['MESSAGE_TTL'] = '4000' assert count_replica_rows(replica_cursor, REPLICA_BASIC_TABLE) == 0 assert count_replica_rows(replica_cursor, REPLICA_EVENT_TABLE) == 0 @@ -23,7 +24,7 @@ def test_both_consumers_consume(replica_cursor): ]) BasicFieldsModel.call_post_bulk_create(master_instances) - transport_delay(3) + transport_delay(5) assert count_replica_rows(replica_cursor, REPLICA_BASIC_TABLE) == 9 assert count_replica_rows(replica_cursor, REPLICA_EVENT_TABLE) == 9 @@ -32,13 +33,18 @@ def test_both_consumers_consume(replica_cursor): @pytest.mark.django_db(transaction=True) -def test_de_duplication(replica_cursor): +def test_de_duplication(settings, replica_cursor, clean_rabbit_transport_connection): + settings.CQRS['MESSAGE_TTL'] = '4000' assert count_replica_rows(replica_cursor, REPLICA_BASIC_TABLE) == 0 assert count_replica_rows(replica_cursor, REPLICA_EVENT_TABLE) == 0 - master_instance = BasicFieldsModel.objects.create(int_field=1, char_field='text') - BasicFieldsModel.call_post_bulk_create([master_instance for _ in range(9)]) + master_instance = BasicFieldsModel.objects.create(int_field=21, char_field='text') + BasicFieldsModel.call_post_bulk_create([master_instance]) + transport_delay(5) - transport_delay(3) + replica_cursor.execute('TRUNCATE TABLE {};'.format(REPLICA_EVENT_TABLE)) + BasicFieldsModel.call_post_bulk_create([master_instance for _ in range(10)]) + + transport_delay(5) assert count_replica_rows(replica_cursor, REPLICA_BASIC_TABLE) == 1 assert count_replica_rows(replica_cursor, REPLICA_EVENT_TABLE) == 10 diff --git a/integration_tests/tests/test_bulk_operations.py b/integration_tests/tests/test_bulk_operations.py index 6d409a3..277e101 100644 --- a/integration_tests/tests/test_bulk_operations.py +++ b/integration_tests/tests/test_bulk_operations.py @@ -9,7 +9,7 @@ @pytest.mark.django_db(transaction=True) -def test_flow(replica_cursor): +def test_flow(replica_cursor, clean_rabbit_transport_connection): assert count_replica_rows(replica_cursor, REPLICA_BASIC_TABLE) == 0 # Create diff --git a/integration_tests/tests/test_single_basic_instance.py b/integration_tests/tests/test_single_basic_instance.py index 031a21d..750eca7 100644 --- a/integration_tests/tests/test_single_basic_instance.py +++ b/integration_tests/tests/test_single_basic_instance.py @@ -10,7 +10,7 @@ @pytest.mark.django_db(transaction=True) -def test_flow(replica_cursor): +def test_flow(replica_cursor, clean_rabbit_transport_connection): assert count_replica_rows(replica_cursor, REPLICA_BASIC_TABLE) == 0 # Create diff --git a/integration_tests/tests/test_sync_to_a_certain_service.py b/integration_tests/tests/test_sync_to_a_certain_service.py index ce1bc35..9d1aa48 100644 --- a/integration_tests/tests/test_sync_to_a_certain_service.py +++ b/integration_tests/tests/test_sync_to_a_certain_service.py @@ -9,7 +9,7 @@ @pytest.mark.django_db(transaction=True) -def test_flow(replica_cursor, mocker): +def test_flow(replica_cursor, mocker, clean_rabbit_transport_connection): assert count_replica_rows(replica_cursor, REPLICA_BASIC_TABLE) == 0 # Create