diff --git a/dj_cqrs/mixins.py b/dj_cqrs/mixins.py index a4cfb92..dbb62c9 100644 --- a/dj_cqrs/mixins.py +++ b/dj_cqrs/mixins.py @@ -118,11 +118,12 @@ def save(self, *args, **kwargs): def save_tracked_fields(self): if hasattr(self, FIELDS_TRACKER_FIELD_NAME): tracker = getattr(self, FIELDS_TRACKER_FIELD_NAME) - if self._state.adding: - data = tracker.changed_initial() - else: - data = tracker.changed() - setattr(self, TRACKED_FIELDS_ATTR_NAME, data) + if self.is_initial_cqrs_save: + if self._state.adding: + data = tracker.changed_initial() + else: + data = tracker.changed() + setattr(self, TRACKED_FIELDS_ATTR_NAME, data) @property def _update_cqrs_fields_default(self): diff --git a/requirements/dev.txt b/requirements/dev.txt index 062c7ac..bc46d0f 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -1,4 +1,4 @@ -Django >= 1.11.20 +Django>= 1.11.20,<4 pika>=1.0.0 kombu==4.6.* ujson==3.0.0 diff --git a/tests/test_master/test_mixin.py b/tests/test_master/test_mixin.py index 5913448..e93cccb 100644 --- a/tests/test_master/test_mixin.py +++ b/tests/test_master/test_mixin.py @@ -734,7 +734,7 @@ def test_transaction_instance_saved_once_simple_case(mocker): mapper = ( (i0.pk, 0, 'old', None), - (i1.pk, 0, '2', '1'), + (i1.pk, 0, '2', None), (i2.pk, 0, 'a', None), (i3.pk, 0, '.', None), (i0.pk, 1, 'new', 'old'), @@ -749,6 +749,27 @@ def test_transaction_instance_saved_once_simple_case(mocker): assert payload.previous_data['char_field'] == expected_data[3] +@pytest.mark.django_db(transaction=True) +def test_transaction_instance_saved_multiple_times_previous_data(mocker): + publisher_mock = mocker.patch('dj_cqrs.controller.producer.produce') + instance = models.TrackedFieldsParentModel.objects.create(char_field='db_value') + + with transaction.atomic(): + instance.refresh_from_db() + instance.char_field = 'save_1' + instance.save() + instance.char_field = 'save_2' + instance.save() + + assert publisher_mock.call_count == 2 + payload_create = publisher_mock.call_args_list[0][0][0] + payload_update = publisher_mock.call_args_list[1][0][0] + assert payload_create.instance_data['char_field'] == 'db_value' + assert payload_create.previous_data['char_field'] is None + assert payload_update.instance_data['char_field'] == 'save_2' + assert payload_update.previous_data['char_field'] == 'db_value' + + @pytest.mark.django_db(transaction=True) def test_cqrs_saves_count_lifecycle(): instance = models.TrackedFieldsParentModel(char_field='1')