diff --git a/simple_history/tests/tests/test_utils.py b/simple_history/tests/tests/test_utils.py index 7db701d9..ed034f98 100644 --- a/simple_history/tests/tests/test_utils.py +++ b/simple_history/tests/tests/test_utils.py @@ -37,6 +37,7 @@ get_m2m_reverse_field_name, update_change_reason, ) +from .utils import db_supports_returning_autofield_pks User = get_user_model() @@ -155,6 +156,28 @@ def test_bulk_create_history_with_disabled_setting(self): self.assertEqual(Poll.objects.count(), 5) self.assertEqual(Poll.history.count(), 0) + def test_bulk_create_history_without_pks(self): + for poll in self.data: + poll.pk = None + + # An extra query must be made on some DBs to retrieve the auto-generated PKs + with self.assertNumQueries(2 if db_supports_returning_autofield_pks() else 3): + bulk_create_with_history(self.data, Poll) + + self.assertEqual(Poll.objects.count(), 5) + self.assertEqual(Poll.history.count(), 5) + + def test_bulk_create_history_without_some_pks(self): + self.data[1].pk = None + self.data[3].pk = None + + # An extra query must be made on some DBs to retrieve the auto-generated PKs + with self.assertNumQueries(3 if db_supports_returning_autofield_pks() else 4): + bulk_create_with_history(self.data, Poll) + + self.assertEqual(Poll.objects.count(), 5) + self.assertEqual(Poll.history.count(), 5) + def test_bulk_create_history_alternative_manager(self): bulk_create_with_history( self.data, diff --git a/simple_history/tests/tests/utils.py b/simple_history/tests/tests/utils.py index ae6fe949..461f85e8 100644 --- a/simple_history/tests/tests/utils.py +++ b/simple_history/tests/tests/utils.py @@ -1,8 +1,10 @@ from enum import Enum from typing import Type +from functools import cache from django.conf import settings from django.db.models import Model +from django.db import connection from django.test import TestCase request_middleware = "simple_history.middleware.HistoryRequestMiddleware" @@ -35,6 +37,16 @@ def assertRecordValues(self, record, klass: Type[Model], values_dict: dict): self.assertEqual(getattr(record.history_object, key), value) +@cache +def db_supports_returning_autofield_pks() -> bool: + # See https://docs.djangoproject.com/en/stable/ref/models/querysets/#bulk-create + return connection.display_name.lower() in { + "postgresql", + "mariadb", + "sqlite", + } + + class TestDbRouter: def db_for_read(self, model, **hints): if model._meta.app_label == "external":