diff --git a/simple_history/template_utils.py b/simple_history/template_utils.py index eb871c20..0f1e062b 100644 --- a/simple_history/template_utils.py +++ b/simple_history/template_utils.py @@ -3,14 +3,33 @@ from typing import Any, Dict, Final, List, Tuple, Type, Union from django.db.models import ManyToManyField, Model -from django.template.defaultfilters import truncatechars_html from django.utils.html import conditional_escape +from django.utils.safestring import SafeString, mark_safe from django.utils.text import capfirst from .models import HistoricalChanges, ModelChange, ModelDelta, PKOrRelatedObj from .utils import get_m2m_reverse_field_name +def conditional_str(obj: Any) -> str: + """ + Converts ``obj`` to a string, unless it's already one. + """ + if isinstance(obj, str): + return obj + return str(obj) + + +def is_safe_str(s: Any) -> bool: + """ + Returns whether ``s`` is a (presumably) pre-escaped string or not. + + This relies on the same ``__html__`` convention as Django's ``conditional_escape`` + does. + """ + return hasattr(s, "__html__") + + class HistoricalRecordContextHelper: """ Class containing various utilities for formatting the template context for @@ -58,17 +77,17 @@ def format_delta_change(self, change: ModelChange) -> ModelChange: Return a ``ModelChange`` object with fields formatted for being used as template context. """ + old = self.prepare_delta_change_value(change, change.old) + new = self.prepare_delta_change_value(change, change.new) - def format_value(value): - value = self.prepare_delta_change_value(change, value) - return self.stringify_delta_change_value(change, value) + old, new = self.stringify_delta_change_values(change, old, new) field_meta = self.model._meta.get_field(change.field) return dataclasses.replace( change, field=capfirst(field_meta.verbose_name), - old=format_value(change.old), - new=format_value(change.new), + old=old, + new=new, ) def prepare_delta_change_value( @@ -78,12 +97,11 @@ def prepare_delta_change_value( ) -> Any: """ Return the prepared value for the ``old`` and ``new`` fields of ``change``, - before it's passed through ``stringify_delta_change_value()`` (in + before it's passed through ``stringify_delta_change_values()`` (in ``format_delta_change()``). For example, if ``value`` is a list of M2M related objects, it could be - "prepared" by replacing the related objects with custom string representations, - or by returning a more nicely formatted HTML string. + "prepared" by replacing the related objects with custom string representations. :param change: :param value: Either ``change.old`` or ``change.new``. @@ -99,23 +117,46 @@ def prepare_delta_change_value( display_value = value return display_value - def stringify_delta_change_value(self, change: ModelChange, value: Any) -> str: + def stringify_delta_change_values( + self, change: ModelChange, old: Any, new: Any + ) -> Tuple[SafeString, SafeString]: """ - Return the displayed value for the ``old`` and ``new`` fields of ``change``, - after it's prepared by ``prepare_delta_change_value()``. + Called by ``format_delta_change()`` after ``old`` and ``new`` have been + prepared by ``prepare_delta_change_value()``. - :param change: - :param value: Either ``change.old`` or ``change.new``, as returned by - ``prepare_delta_change_value()``. + Return a tuple -- ``(old, new)`` -- where each element has been + escaped/sanitized and turned into strings, ready to be displayed in a template. + These can be HTML strings (remember to pass them through ``mark_safe()`` *after* + escaping). """ - # If `value` is a list, stringify it using `str()` instead of `repr()` - # (the latter of which is the default when stringifying lists) - if isinstance(value, list): - value = f'[{", ".join(map(str, value))}]' - value = conditional_escape(value) - value = truncatechars_html(value, self.max_displayed_delta_change_chars) - return value + def stringify_value(value) -> Union[str, SafeString]: + # If `value` is a list, stringify each element using `str()` instead of + # `repr()` (the latter is the default when calling `list.__str__()`) + if isinstance(value, list): + string = f"[{', '.join(map(conditional_str, value))}]" + # If all elements are safe strings, reapply `mark_safe()` + if all(map(is_safe_str, value)): + string = mark_safe(string) # nosec + else: + string = conditional_str(value) + return string + + old_str, new_str = stringify_value(old), stringify_value(new) + diff_display = self.get_obj_diff_display() + old_short, new_short = diff_display.common_shorten_repr(old_str, new_str) + # Escape *after* shortening, as any shortened, previously safe HTML strings have + # likely been mangled. Other strings that have not been shortened, should have + # their "safeness" unchanged + return conditional_escape(old_short), conditional_escape(new_short) + + def get_obj_diff_display(self) -> "ObjDiffDisplay": + """ + Return an instance of ``ObjDiffDisplay`` that will be used in + ``stringify_delta_change_values()`` to display the difference between + the old and new values of a ``ModelChange``. + """ + return ObjDiffDisplay(max_length=self.max_displayed_delta_change_chars) class ObjDiffDisplay: @@ -158,45 +199,47 @@ def common_shorten_repr(self, *args: Any) -> Tuple[str, ...]: so that the first differences between the strings (after a potential common prefix in all of them) are lined up. """ - args = tuple(map(self.safe_repr, args)) - maxlen = max(map(len, args)) - if maxlen <= self.max_length: + args = tuple(map(conditional_str, args)) + max_len = max(map(len, args)) + if max_len <= self.max_length: return args prefix = commonprefix(args) - prefixlen = len(prefix) + prefix_len = len(prefix) common_len = self.max_length - ( - maxlen - prefixlen + self.min_begin_len + self.placeholder_len + max_len - prefix_len + self.min_begin_len + self.placeholder_len ) if common_len > self.min_common_len: assert ( self.min_begin_len + self.placeholder_len + self.min_common_len - + (maxlen - prefixlen) + + (max_len - prefix_len) < self.max_length ) # nosec prefix = self.shorten(prefix, self.min_begin_len, common_len) - return tuple(prefix + s[prefixlen:] for s in args) + return tuple(f"{prefix}{s[prefix_len:]}" for s in args) prefix = self.shorten(prefix, self.min_begin_len, self.min_common_len) return tuple( - prefix + self.shorten(s[prefixlen:], self.min_diff_len, self.min_end_len) + prefix + self.shorten(s[prefix_len:], self.min_diff_len, self.min_end_len) for s in args ) - def safe_repr(self, obj: Any, short=False) -> str: - try: - result = repr(obj) - except Exception: - result = object.__repr__(obj) - if not short or len(result) < self.max_length: - return result - return result[: self.max_length] + " [truncated]..." - - def shorten(self, s: str, prefixlen: int, suffixlen: int) -> str: - skip = len(s) - prefixlen - suffixlen + def shorten(self, s: str, prefix_len: int, suffix_len: int) -> str: + skip = len(s) - prefix_len - suffix_len if skip > self.placeholder_len: - s = "%s[%d chars]%s" % (s[:prefixlen], skip, s[len(s) - suffixlen :]) + suffix_index = len(s) - suffix_len + s = self.shortened_str(s[:prefix_len], skip, s[suffix_index:]) return s + + def shortened_str(self, prefix: str, num_skipped_chars: int, suffix: str) -> str: + """ + Return a shortened version of the string representation of one of the args + passed to ``common_shorten_repr()``. + This should be in the format ``f"{prefix}{skip_str}{suffix}"``, where + ``skip_str`` is a string indicating how many characters (``num_skipped_chars``) + of the string representation were skipped between ``prefix`` and ``suffix``. + """ + return f"{prefix}[{num_skipped_chars:d} chars]{suffix}" diff --git a/simple_history/tests/admin.py b/simple_history/tests/admin.py index 2395f47d..cc6aaf90 100644 --- a/simple_history/tests/admin.py +++ b/simple_history/tests/admin.py @@ -1,6 +1,8 @@ from django.contrib import admin +from django.utils.safestring import SafeString, mark_safe from simple_history.admin import SimpleHistoryAdmin +from simple_history.template_utils import HistoricalRecordContextHelper from simple_history.tests.external.models import ExternalModelWithCustomUserIdField from .models import ( @@ -12,6 +14,7 @@ FileModel, Paper, Person, + Place, Planet, Poll, PollWithManyToMany, @@ -44,6 +47,27 @@ def test_method(self, obj): history_list_display = ["title", "test_method"] +class HistoricalPollWithManyToManyContextHelper(HistoricalRecordContextHelper): + def prepare_delta_change_value(self, change, value): + display_value = super().prepare_delta_change_value(change, value) + if change.field == "places": + assert isinstance(display_value, list) + assert all(isinstance(place, Place) for place in display_value) + + places = sorted(display_value, key=lambda place: place.name) + display_value = list(map(self.place_display, places)) + return display_value + + @staticmethod + def place_display(place: Place) -> SafeString: + return mark_safe(f"{place.name}") + + +class PollWithManyToManyAdmin(SimpleHistoryAdmin): + def get_historical_record_context_helper(self, request, historical_record): + return HistoricalPollWithManyToManyContextHelper(self.model, historical_record) + + admin.site.register(Book, SimpleHistoryAdmin) admin.site.register(Choice, ChoiceAdmin) admin.site.register(ConcreteExternal, SimpleHistoryAdmin) @@ -55,4 +79,4 @@ def test_method(self, obj): admin.site.register(Person, PersonAdmin) admin.site.register(Planet, PlanetAdmin) admin.site.register(Poll, SimpleHistoryAdmin) -admin.site.register(PollWithManyToMany, SimpleHistoryAdmin) +admin.site.register(PollWithManyToMany, PollWithManyToManyAdmin) diff --git a/simple_history/tests/tests/test_admin.py b/simple_history/tests/tests/test_admin.py index 73dadfb2..02b5bdc5 100644 --- a/simple_history/tests/tests/test_admin.py +++ b/simple_history/tests/tests/test_admin.py @@ -159,6 +159,11 @@ def test_history_list_contains_diff_changes_for_foreign_key_fields(self): self.assertContains(response, f"Deleted poll (pk={poll1_pk})") self.assertContains(response, f"Deleted poll (pk={poll2_pk})") + @patch( + # Test without the customization in PollWithManyToMany's admin class + "simple_history.tests.admin.HistoricalPollWithManyToManyContextHelper", + HistoricalRecordContextHelper, + ) def test_history_list_contains_diff_changes_for_m2m_fields(self): self.login() poll = PollWithManyToMany(question="why?", pub_date=today) @@ -199,20 +204,54 @@ def test_history_list_contains_diff_changes_for_m2m_fields(self): def test_history_list_doesnt_contain_too_long_diff_changes(self): self.login() - repeated_chars = Poll._meta.get_field("question").max_length - 1 - poll = Poll(question=f"W{'A' * repeated_chars}", pub_date=today) - poll._history_user = self.user - poll.save() - poll.question = f"W{'E' * repeated_chars}" - poll.save() - response = self.client.get(get_history_url(poll)) - self.assertContains(response, "Question") + def create_and_change_poll(*, initial_question, changed_question) -> Poll: + poll = Poll(question=initial_question, pub_date=today) + poll._history_user = self.user + poll.save() + poll.question = changed_question + poll.save() + return poll + repeated_chars = ( - HistoricalRecordContextHelper.DEFAULT_MAX_DISPLAYED_DELTA_CHANGE_CHARS - 2 + HistoricalRecordContextHelper.DEFAULT_MAX_DISPLAYED_DELTA_CHANGE_CHARS + ) + + # Number of characters right on the limit + poll1 = create_and_change_poll( + initial_question="A" * repeated_chars, + changed_question="B" * repeated_chars, ) - self.assertContains(response, f"W{'A' * repeated_chars}…") - self.assertContains(response, f"W{'E' * repeated_chars}…") + response = self.client.get(get_history_url(poll1)) + self.assertContains(response, "Question:") + self.assertContains(response, "A" * repeated_chars) + self.assertContains(response, "B" * repeated_chars) + + # Number of characters just over the limit + poll2 = create_and_change_poll( + initial_question="A" * (repeated_chars + 1), + changed_question="B" * (repeated_chars + 1), + ) + response = self.client.get(get_history_url(poll2)) + self.assertContains(response, "Question:") + self.assertContains(response, f"{'A' * 61}[35 chars]AAAAA") + self.assertContains(response, f"{'B' * 61}[35 chars]BBBBB") + + def test_overriding__historical_record_context_helper__with_custom_m2m_string(self): + self.login() + + place1 = Place.objects.create(name="Place 1") + place2 = Place.objects.create(name="Place 2") + place3 = Place.objects.create(name="Place 3") + poll = PollWithManyToMany.objects.create(question="why?", pub_date=today) + poll.places.add(place1, place2) + poll.places.set([place3]) + + response = self.client.get(get_history_url(poll)) + self.assertContains(response, "Places:") + self.assertContains(response, "[]") + self.assertContains(response, "[Place 1, Place 2]") + self.assertContains(response, "[Place 3]") def test_history_list_custom_fields(self): model_name = self.user._meta.model_name diff --git a/simple_history/tests/tests/test_template_utils.py b/simple_history/tests/tests/test_template_utils.py new file mode 100644 index 00000000..d159ba81 --- /dev/null +++ b/simple_history/tests/tests/test_template_utils.py @@ -0,0 +1,314 @@ +from datetime import datetime +from typing import Tuple + +from django.test import TestCase +from django.utils.dateparse import parse_datetime +from django.utils.safestring import mark_safe + +from simple_history.models import ModelChange, ModelDelta +from simple_history.template_utils import HistoricalRecordContextHelper, is_safe_str + +from ...tests.models import Choice, Place, Poll, PollWithManyToMany + + +class HistoricalRecordContextHelperTestCase(TestCase): + + def test__context_for_delta_changes__basic_usage_works_as_expected(self): + # --- Text and datetimes --- + + old_date = "2021-01-01 12:00:00" + poll = Poll.objects.create(question="old?", pub_date=parse_datetime(old_date)) + new_date = "2021-01-02 12:00:00" + poll.question = "new?" + poll.pub_date = parse_datetime(new_date) + poll.save() + + new, old = poll.history.all() + expected_context_list = [ + { + "field": "Date published", + "old": old_date, + "new": new_date, + }, + { + "field": "Question", + "old": "old?", + "new": "new?", + }, + ] + self.assert__context_for_delta_changes__equal( + Poll, old, new, expected_context_list + ) + + # --- Foreign keys and ints --- + + poll1 = Poll.objects.create(question="1?", pub_date=datetime.now()) + poll2 = Poll.objects.create(question="2?", pub_date=datetime.now()) + choice = Choice.objects.create(poll=poll1, votes=1) + choice.poll = poll2 + choice.votes = 10 + choice.save() + + new, old = choice.history.all() + expected_context_list = [ + { + "field": "Poll", + "old": f"Poll object ({poll1.pk})", + "new": f"Poll object ({poll2.pk})", + }, + { + "field": "Votes", + "old": "1", + "new": "10", + }, + ] + self.assert__context_for_delta_changes__equal( + Choice, old, new, expected_context_list + ) + + # --- M2M objects, text and datetimes (across 3 records) --- + + poll = PollWithManyToMany.objects.create( + question="old?", pub_date=parse_datetime(old_date) + ) + poll.question = "new?" + poll.pub_date = parse_datetime(new_date) + poll.save() + place1 = Place.objects.create(name="Place 1") + place2 = Place.objects.create(name="Place 2") + poll.places.add(place1, place2) + + newest, _middle, oldest = poll.history.all() + expected_context_list = [ + # (The dicts should be sorted by the fields' attribute names) + { + "field": "Places", + "old": "[]", + "new": f"[Place object ({place1.pk}), Place object ({place2.pk})]", + }, + { + "field": "Date published", + "old": old_date, + "new": new_date, + }, + { + "field": "Question", + "old": "old?", + "new": "new?", + }, + ] + self.assert__context_for_delta_changes__equal( + PollWithManyToMany, oldest, newest, expected_context_list + ) + + def assert__context_for_delta_changes__equal( + self, model, old_record, new_record, expected_context_list + ): + delta = new_record.diff_against(old_record, foreign_keys_are_objs=True) + context_helper = HistoricalRecordContextHelper(model, new_record) + context_list = context_helper.context_for_delta_changes(delta) + self.assertListEqual(context_list, expected_context_list) + + def test__context_for_delta_changes__with_string_len_around_character_limit(self): + now = datetime.now() + + def test_context_dict( + *, initial_question, changed_question, expected_old, expected_new + ) -> None: + poll = Poll.objects.create(question=initial_question, pub_date=now) + poll.question = changed_question + poll.save() + new, old = poll.history.all() + + expected_context_dict = { + "field": "Question", + "old": expected_old, + "new": expected_new, + } + self.assert__context_for_delta_changes__equal( + Poll, old, new, [expected_context_dict] + ) + # Flipping the records should produce the same result (other than also + # flipping the expected "old" and "new" values, of course) + expected_context_dict = { + "field": "Question", + "old": expected_new, + "new": expected_old, + } + self.assert__context_for_delta_changes__equal( + Poll, new, old, [expected_context_dict] + ) + + # Check the character limit used in the assertions below + self.assertEqual( + HistoricalRecordContextHelper.DEFAULT_MAX_DISPLAYED_DELTA_CHANGE_CHARS, 100 + ) + + # Number of characters right on the limit + test_context_dict( + initial_question=f"Y{'A' * 99}", + changed_question=f"W{'A' * 99}", + expected_old=f"Y{'A' * 99}", + expected_new=f"W{'A' * 99}", + ) + + # Over the character limit, with various ways that a shared prefix affects how + # the shortened strings are lined up with each other + test_context_dict( + initial_question=f"Y{'A' * 100}", + changed_question=f"W{'A' * 100}", + expected_old=f"Y{'A' * 60}[35 chars]AAAAA", + expected_new=f"W{'A' * 60}[35 chars]AAAAA", + ) + test_context_dict( + initial_question=f"{'A' * 100}Y", + changed_question=f"{'A' * 100}W", + expected_old=f"AAAAA[13 chars]{'A' * 82}Y", + expected_new=f"AAAAA[13 chars]{'A' * 82}W", + ) + test_context_dict( + initial_question=f"{'A' * 100}Y", + changed_question=f"{'A' * 200}W", + expected_old="AAAAA[90 chars]AAAAAY", + expected_new=f"AAAAA[90 chars]{'A' * 66}[35 chars]AAAAW", + ) + test_context_dict( + initial_question=f"{'A' * 100}Y{'E' * 100}", + changed_question=f"{'A' * 100}W{'E' * 200}", + expected_old=f"AAAAA[90 chars]AAAAAY{'E' * 60}[35 chars]EEEEE", + expected_new=f"AAAAA[90 chars]AAAAAW{'E' * 60}[135 chars]EEEEE", + ) + test_context_dict( + initial_question=f"{'A' * 100}Y{'E' * 200}", + changed_question=f"{'A' * 200}W{'E' * 100}", + expected_old=f"AAAAA[90 chars]AAAAAY{'E' * 60}[135 chars]EEEEE", + expected_new=f"AAAAA[90 chars]{'A' * 66}[135 chars]EEEEE", + ) + + # Only similar prefixes are detected and lined up; + # similar parts later in the strings are not + test_context_dict( + initial_question=f"{'Y' * 100}{'A' * 100}", + changed_question=f"{'W' * 100}{'A' * 100}{'H' * 100}", + expected_old=f"{'Y' * 61}[134 chars]AAAAA", + expected_new=f"{'W' * 61}[234 chars]HHHHH", + ) + + # Both "old" and "new" under the character limit + test_context_dict( + initial_question="A" * 10, + changed_question="A" * 100, + expected_old="A" * 10, + expected_new="A" * 100, + ) + # "new" just over the limit, but with "old" too short to be shortened + test_context_dict( + initial_question="A" * 10, + changed_question="A" * 101, + expected_old="A" * 10, + expected_new=f"{'A' * 71}[25 chars]AAAAA", + ) + # Both "old" and "new" under the character limit + test_context_dict( + initial_question="A" * 99, + changed_question="A" * 100, + expected_old="A" * 99, + expected_new="A" * 100, + ) + # "new" just over the limit, and "old" long enough to be shortened (which is + # done even if it's shorter than the character limit) + test_context_dict( + initial_question="A" * 99, + changed_question="A" * 101, + expected_old=f"AAAAA[13 chars]{'A' * 81}", + expected_new=f"AAAAA[13 chars]{'A' * 83}", + ) + + def test__context_for_delta_changes__preserves_html_safe_strings(self): + def get_context_dict_old_and_new(old_value, new_value) -> Tuple[str, str]: + # The field doesn't really matter, as long as it exists on the model + # passed to `HistoricalRecordContextHelper` + change = ModelChange("question", old_value, new_value) + # (The record args are not (currently) used in the default implementation) + delta = ModelDelta([change], ["question"], None, None) + context_helper = HistoricalRecordContextHelper(Poll, None) + (context_dict,) = context_helper.context_for_delta_changes(delta) + return context_dict["old"], context_dict["new"] + + # Strings not marked as safe should be escaped + old_string = "Hey" + new_string = "Hello" + old, new = get_context_dict_old_and_new(old_string, new_string) + self.assertEqual(old, "<i>Hey</i>") + self.assertEqual(new, "<b>Hello</b>") + # The result should still be marked safe as part of being escaped + self.assertTrue(is_safe_str(old) and is_safe_str(new)) + + # Strings marked as safe should be kept unchanged... + old_safe_string = mark_safe("Hey") + new_safe_string = mark_safe("Hello") + old, new = get_context_dict_old_and_new(old_safe_string, new_safe_string) + self.assertEqual(old, old_safe_string) + self.assertEqual(new, new_safe_string) + self.assertTrue(is_safe_str(old) and is_safe_str(new)) + + # ...also if one is safe and the other isn't... + old_string = "Hey" + new_safe_string = mark_safe("Hello") + old, new = get_context_dict_old_and_new(old_string, new_safe_string) + self.assertEqual(old, "<i>Hey</i>") + self.assertEqual(new, new_safe_string) + self.assertTrue(is_safe_str(old) and is_safe_str(new)) + + # ...unless at least one of them is too long, in which case they should both be + # properly escaped - including mangled tags + old_safe_string = mark_safe(f"
{'A' * 1000}
") + new_safe_string = mark_safe("Hello
") + old, new = get_context_dict_old_and_new(old_safe_string, new_safe_string) + # (`` has been mangled) + expected_old = f"<p><strong>{'A' * 61}[947 chars]></p>" + self.assertEqual(old, expected_old) + self.assertEqual(new, "<p><strong>Hello</strong></p>") + self.assertTrue(is_safe_str(old) and is_safe_str(new)) + + # Unsafe strings inside lists should also be escaped + old_list = ["Hey", "Hey"] + new_list = ["Hello", "Hello"] + old, new = get_context_dict_old_and_new(old_list, new_list) + self.assertEqual(old, "[Hey, <i>Hey</i>]") + self.assertEqual(new, "[<b>Hello</b>, Hello]") + self.assertTrue(is_safe_str(old) and is_safe_str(new)) + + # Safe strings inside lists should be kept unchanged... + old_safe_list = [mark_safe("Hey"), mark_safe("Hey")] + new_safe_list = [mark_safe("Hello"), mark_safe("Hello")] + old, new = get_context_dict_old_and_new(old_safe_list, new_safe_list) + self.assertEqual(old, "[Hey, Hey]") + self.assertEqual(new, "[Hello, Hello]") + self.assertTrue(is_safe_str(old) and is_safe_str(new)) + + # ...but not when not all elements are safe... + old_half_safe_list = [mark_safe("Hey"), "Hey"] + new_half_safe_list = [mark_safe("Hello"), "Hello"] + old, new = get_context_dict_old_and_new(old_half_safe_list, new_half_safe_list) + self.assertEqual(old, "[Hey, <i>Hey</i>]") + self.assertEqual(new, "[<b>Hello</b>, Hello]") + self.assertTrue(is_safe_str(old) and is_safe_str(new)) + + # ...and also not when some of the elements are too long + old_safe_list = [mark_safe("Hey"), mark_safe(f"{'A' * 1000}")] + new_safe_list = [mark_safe("Hello"), mark_safe(f"{'B' * 1000}")] + old, new = get_context_dict_old_and_new(old_safe_list, new_safe_list) + self.assertEqual(old, f"[Hey, <i>{'A' * 53}[947 chars]</i>]") + self.assertEqual(new, f"[<b>Hello</b>, {'B' * 47}[949 chars]BBBB]") + self.assertTrue(is_safe_str(old) and is_safe_str(new)) + + # HTML tags inside too long strings should be properly escaped - including + # mangled tags + old_safe_list = [mark_safe(f"