Skip to content

Commit

Permalink
[TODO: PR numbers] Code cleanup of some util functions
Browse files Browse the repository at this point in the history
Incl. removing `HistoryManager.get_super_queryset()`
and renaming the following `utils` functions:
* get_history_manager_from_history -> get_historical_records_of_instance
* get_app_model_primary_key_name -> get_pk_name

Also added some tests for some of the changed util functions.
  • Loading branch information
ddabble committed Sep 8, 2024
1 parent 718e732 commit d150848
Show file tree
Hide file tree
Showing 6 changed files with 266 additions and 40 deletions.
7 changes: 7 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@ Unreleased
- Added ``delete_without_historical_record()`` to all history-tracked model objects,
which complements ``save_without_historical_record()`` (gh-13xx)

**Breaking changes:**

- Removed ``HistoryManager.get_super_queryset()`` (gh-13xx)
- Renamed the ``utils`` functions ``get_history_manager_from_history()``
to ``get_historical_records_of_instance()`` and ``get_app_model_primary_key_name()``
to ``get_pk_name()`` (gh-13xx)

**Deprecations:**

- Deprecated the undocumented ``HistoricalRecords.thread`` - use
Expand Down
16 changes: 5 additions & 11 deletions simple_history/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@
from django.db.models import Exists, OuterRef, Q, QuerySet
from django.utils import timezone

from simple_history.utils import (
get_app_model_primary_key_name,
get_change_reason_from_object,
)
from . import utils

# when converting a historical record to an instance, this attribute is added
# to the instance so that code can reverse the instance to its historical record
Expand Down Expand Up @@ -118,16 +115,13 @@ def __init__(self, model, instance=None):
self.model = model
self.instance = instance

def get_super_queryset(self):
return super().get_queryset()

def get_queryset(self):
qs = self.get_super_queryset()
qs = super().get_queryset()
if self.instance is None:
return qs

key_name = get_app_model_primary_key_name(self.instance)
return self.get_super_queryset().filter(**{key_name: self.instance.pk})
pk_name = utils.get_pk_name(self.instance._meta.model)
return qs.filter(**{pk_name: self.instance.pk})

def most_recent(self):
"""
Expand Down Expand Up @@ -241,7 +235,7 @@ def bulk_history_create(
instance, "_history_date", default_date or timezone.now()
),
history_user=history_user,
history_change_reason=get_change_reason_from_object(instance)
history_change_reason=utils.get_change_reason_from_object(instance)
or default_change_reason,
history_type=history_type,
**{
Expand Down
4 changes: 2 additions & 2 deletions simple_history/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ def get_next_record(self):
"""
Get the next history record for the instance. `None` if last.
"""
history = utils.get_history_manager_from_history(self)
history = utils.get_historical_records_of_instance(self)
return (
history.filter(history_date__gt=self.history_date)
.order_by("history_date")
Expand All @@ -578,7 +578,7 @@ def get_prev_record(self):
"""
Get the previous history record for the instance. `None` if first.
"""
history = utils.get_history_manager_from_history(self)
history = utils.get_historical_records_of_instance(self)
return (
history.filter(history_date__lt=self.history_date)
.order_by("history_date")
Expand Down
211 changes: 211 additions & 0 deletions simple_history/tests/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,88 @@
import unittest
from dataclasses import dataclass
from datetime import datetime
from typing import Optional, Type
from unittest import skipUnless
from unittest.mock import Mock, patch

import django
from django.contrib.auth import get_user_model
from django.db import IntegrityError, transaction
from django.db.models import Model
from django.test import TestCase, TransactionTestCase, override_settings
from django.utils import timezone

from simple_history.exceptions import AlternativeManagerError, NotHistoricalModelError
from simple_history.manager import HistoryManager
from simple_history.models import HistoricalChanges
from simple_history.utils import (
bulk_create_with_history,
bulk_update_with_history,
get_historical_records_of_instance,
get_history_manager_for_model,
get_history_model_for_model,
get_m2m_field_name,
get_m2m_reverse_field_name,
get_pk_name,
update_change_reason,
)

from ..external import models as external
from ..models import (
AbstractBase,
AbstractModelCallable1,
BaseModel,
Book,
BulkCreateManyToManyModel,
Choice,
ConcreteAttr,
ConcreteExternal,
ConcreteUtil,
Contact,
ContactRegister,
CustomManagerNameModel,
Document,
ExternalModelSpecifiedWithAppParam,
ExternalModelWithAppLabel,
FirstLevelInheritedModel,
HardbackBook,
HistoricalBook,
HistoricalPoll,
HistoricalPollInfo,
InheritTracking1,
ModelWithHistoryInDifferentApp,
ModelWithHistoryUsingBaseModelDb,
OverrideModelNameAsCallable,
OverrideModelNameRegisterMethod1,
OverrideModelNameUsingBaseModel1,
Place,
Poll,
PollChildBookWithManyToMany,
PollChildRestaurantWithManyToMany,
PollInfo,
PollParentWithManyToMany,
PollWithAlternativeManager,
PollWithCustomManager,
PollWithExcludedFKField,
PollWithExcludeFields,
PollWithHistoricalSessionAttr,
PollWithManyToMany,
PollWithManyToManyCustomHistoryID,
PollWithManyToManyWithIPAddress,
PollWithQuerySetCustomizations,
PollWithSelfManyToMany,
PollWithSeveralManyToMany,
PollWithUniqueQuestion,
Profile,
Restaurant,
Street,
TestHistoricParticipanToHistoricOrganization,
TestParticipantToHistoricOrganization,
TrackedAbstractBaseA,
TrackedConcreteBase,
TrackedWithAbstractBase,
TrackedWithConcreteBase,
Voter,
)

User = get_user_model()
Expand All @@ -53,6 +99,171 @@ def test_update_change_reason_with_excluded_fields(self):
self.assertEqual(most_recent.history_change_reason, "Test change reason.")


@dataclass
class HistoryTrackedModelTestInfo:
model: Type[Model]
history_manager_name: Optional[str]

def __init__(
self,
model: Type[Model],
history_manager_name: Optional[str] = "history",
):
self.model = model
self.history_manager_name = history_manager_name


class GetHistoryManagerAndModelHelpersTestCase(TestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()

H = HistoryTrackedModelTestInfo
cls.history_tracked_models = [
H(Choice),
H(ConcreteAttr),
H(ConcreteExternal),
H(ConcreteUtil),
H(Contact),
H(ContactRegister),
H(CustomManagerNameModel, "log"),
H(ExternalModelSpecifiedWithAppParam, "histories"),
H(ExternalModelWithAppLabel),
H(InheritTracking1),
H(ModelWithHistoryInDifferentApp),
H(ModelWithHistoryUsingBaseModelDb),
H(OverrideModelNameAsCallable),
H(OverrideModelNameRegisterMethod1),
H(OverrideModelNameUsingBaseModel1),
H(Poll),
H(PollChildBookWithManyToMany),
H(PollWithAlternativeManager),
H(PollWithCustomManager),
H(PollWithExcludedFKField),
H(PollWithHistoricalSessionAttr),
H(PollWithManyToMany),
H(PollWithManyToManyCustomHistoryID),
H(PollWithManyToManyWithIPAddress),
H(PollWithQuerySetCustomizations),
H(PollWithSelfManyToMany),
H(Restaurant, "updates"),
H(TestHistoricParticipanToHistoricOrganization),
H(TrackedConcreteBase),
H(TrackedWithAbstractBase),
H(TrackedWithConcreteBase),
H(Voter),
H(external.ExternalModel),
H(external.ExternalModelRegistered, "histories"),
H(external.Poll),
]
cls.models_without_history_manager = [
H(AbstractBase, None),
H(AbstractModelCallable1, None),
H(BaseModel, None),
H(FirstLevelInheritedModel, None),
H(HardbackBook, None),
H(Place, None),
H(PollParentWithManyToMany, None),
H(Profile, None),
H(TestParticipantToHistoricOrganization, None),
H(TrackedAbstractBaseA, None),
]

def test__get_history_manager_for_model(self):
"""Test that ``get_history_manager_for_model()`` returns the expected value
for various models."""

def assert_history_manager(history_manager, info: HistoryTrackedModelTestInfo):
expected_manager = getattr(info.model, info.history_manager_name)
expected_historical_model = expected_manager.model
historical_model = history_manager.model
# Can't compare the managers directly, as the history manager classes are
# dynamically created through `HistoryDescriptor`
self.assertIsInstance(history_manager, HistoryManager)
self.assertIsInstance(expected_manager, HistoryManager)
self.assertTrue(issubclass(historical_model, HistoricalChanges))
self.assertEqual(historical_model.instance_type, info.model)
self.assertEqual(historical_model, expected_historical_model)

for model_info in self.history_tracked_models:
with self.subTest(model_info=model_info):
model = model_info.model
manager = get_history_manager_for_model(model)
assert_history_manager(manager, model_info)

for model_info in self.models_without_history_manager:
with self.subTest(model_info=model_info):
model = model_info.model
with self.assertRaises(NotHistoricalModelError):
get_history_manager_for_model(model)

def test__get_history_model_for_model(self):
"""Test that ``get_history_model_for_model()`` returns the expected value
for various models."""
for model_info in self.history_tracked_models:
with self.subTest(model_info=model_info):
model = model_info.model
historical_model = get_history_model_for_model(model)
self.assertTrue(issubclass(historical_model, HistoricalChanges))
self.assertEqual(historical_model.instance_type, model)

for model_info in self.models_without_history_manager:
with self.subTest(model_info=model_info):
model = model_info.model
with self.assertRaises(NotHistoricalModelError):
get_history_model_for_model(model)

def test__get_pk_name(self):
"""Test that ``get_pk_name()`` returns the expected value for various models."""
self.assertEqual(get_pk_name(Poll), "id")
self.assertEqual(get_pk_name(PollInfo), "poll_id")
self.assertEqual(get_pk_name(Book), "isbn")

self.assertEqual(get_pk_name(HistoricalPoll), "history_id")
self.assertEqual(get_pk_name(HistoricalPollInfo), "history_id")
self.assertEqual(get_pk_name(HistoricalBook), "history_id")


class GetHistoricalRecordsOfInstanceTestCase(TestCase):
def test__get_historical_records_of_instance(self):
"""Test that ``get_historical_records_of_instance()`` returns the expected
queryset for history-tracked model instances."""
poll1 = Poll.objects.create(pub_date=timezone.now())
poll1_history = poll1.history.all()
(record1_1,) = poll1_history
self.assertQuerySetEqual(
get_historical_records_of_instance(record1_1),
poll1_history,
)

poll2 = Poll.objects.create(pub_date=timezone.now())
poll2.question = "?"
poll2.save()
poll2_history = poll2.history.all()
(record2_2, record2_1) = poll2_history
self.assertQuerySetEqual(
get_historical_records_of_instance(record2_1),
poll2_history,
)
self.assertQuerySetEqual(
get_historical_records_of_instance(record2_2),
poll2_history,
)

poll3 = Poll.objects.create(id=123, pub_date=timezone.now())
poll3.delete()
poll3_history = Poll.history.filter(id=123)
(record3_2, record3_1) = poll3_history
self.assertQuerySetEqual(
get_historical_records_of_instance(record3_1),
poll3_history,
)
self.assertQuerySetEqual(
get_historical_records_of_instance(record3_2),
poll3_history,
)


class GetM2MFieldNamesTestCase(unittest.TestCase):
def test__get_m2m_field_name__returns_expected_value(self):
def field_names(model):
Expand Down
15 changes: 8 additions & 7 deletions simple_history/tests/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,14 @@ def assertRecordValues(self, record, klass: Type[Model], values_dict: dict):
:param klass: The type of the history-tracked class of ``record``.
:param values_dict: Field names of ``record`` mapped to their expected values.
"""
for key, value in values_dict.items():
self.assertEqual(getattr(record, key), value)

self.assertEqual(record.history_object.__class__, klass)
for key, value in values_dict.items():
if key not in ("history_type", "history_change_reason"):
self.assertEqual(getattr(record.history_object, key), value)
for field_name, expected_value in values_dict.items():
self.assertEqual(getattr(record, field_name), expected_value)

history_object = record.history_object
self.assertEqual(history_object.__class__, klass)
for field_name, expected_value in values_dict.items():
if field_name not in ("history_type", "history_change_reason"):
self.assertEqual(getattr(history_object, field_name), expected_value)


class TestDbRouter:
Expand Down
Loading

0 comments on commit d150848

Please sign in to comment.