diff --git a/simple_history/models.py b/simple_history/models.py index 6dc4db9e..41a6f158 100644 --- a/simple_history/models.py +++ b/simple_history/models.py @@ -3,6 +3,7 @@ import uuid import warnings from functools import partial +from typing import TypeVar from django.apps import apps from django.conf import settings @@ -15,7 +16,9 @@ from django.db.models.fields.related import ForeignKey from django.db.models.fields.related_descriptors import ( ForwardManyToOneDescriptor, + ForwardOneToOneDescriptor, ReverseManyToOneDescriptor, + ReverseOneToOneDescriptor, create_reverse_many_to_one_manager, ) from django.db.models.query import QuerySet @@ -45,6 +48,13 @@ except ImportError: from threading import local as LocalContext + +# __set__ value type +_ST = TypeVar("_ST") +# __get__ return type +_GT = TypeVar("_GT") + + registered_models = {} @@ -909,6 +919,23 @@ def to_historic(instance): return getattr(instance, SIMPLE_HISTORY_REVERSE_ATTR_NAME, None) +class HistoricForwardOneToOneDescriptor( + ForwardOneToOneDescriptor, HistoricForwardManyToOneDescriptor +): + pass + + +class HistoricReverseOneToOneDescriptor( + ReverseOneToOneDescriptor, HistoricReverseManyToOneDescriptor +): + pass + + +class HistoricOneToOneField(models.OneToOneField[_ST, _GT]): + forward_related_accessor_class = HistoricForwardOneToOneDescriptor + related_accessor_class = HistoricReverseOneToOneDescriptor + + class HistoricalObjectDescriptor: def __init__(self, model, fields_included): self.model = model