From addd207a923d69cdf17e0d588041f65210f554b8 Mon Sep 17 00:00:00 2001 From: dominikbak Date: Tue, 11 Jul 2023 11:32:01 +0200 Subject: [PATCH] Add ZEN_QUERIES_DISABLED_HANDLER setting --- README.md | 5 +++++ zen_queries/decorators.py | 16 +++++++++++++++- zen_queries/tests/tests.py | 34 ++++++++++++++++++++++++++++++++++ 3 files changed, 54 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index cedca20..e0a8095 100644 --- a/README.md +++ b/README.md @@ -189,6 +189,11 @@ There are {{ pizzas.count }} pizzas. {% end_queries_dangerously_enabled %} ``` +#### Custom disabled queries handler +Zen Queries provides flexibility with the capacity to set a custom handler that can exhibit any behavior you desire when queries are disabled. This custom handler should be a callable (a function or a method). You can configure this custom handler in Zen Queries by specifying the Python import path to the callable as a string in the ZEN_QUERIES_DISABLED_HANDLER setting. + +By default, if queries are disabled and an attempt is made to execute a query, Zen Queries raises an exception. However, with a custom handler, you can override this behavior to better suit your application's specific needs. + ### Permissions gotcha Accessing permissions in your templates (via the `{{ perms }}` template variable) can be a source of queries at template-render time. Fortunately, Django's permission checks are [cached by the `ModelBackend`](https://docs.djangoproject.com/en/2.2/topics/auth/default/#permission-caching), which can be pre-populated by calling `request.user.get_all_permissions()` in the view, before rendering the template. diff --git a/zen_queries/decorators.py b/zen_queries/decorators.py index 8698d40..02f5400 100644 --- a/zen_queries/decorators.py +++ b/zen_queries/decorators.py @@ -1,6 +1,10 @@ from contextlib import contextmanager +from django.conf import settings +from django.core.exceptions import ImproperlyConfigured from django.db import connections +import importlib + class QueriesDisabledError(Exception): pass @@ -10,9 +14,19 @@ def _raise_exception(execute, sql, params, many, context): raise QueriesDisabledError(sql) +def _get_custom_wrapper(): + custom_wrapper_path = getattr(settings, "ZEN_QUERIES_DISABLED_HANDLER", None) + if custom_wrapper_path: + module_path, function_name = custom_wrapper_path.rsplit(".", 1) + module = importlib.import_module(module_path) + return getattr(module, function_name, None) + + def _disable_queries(): + custom_wrapper = _get_custom_wrapper() + for connection in connections.all(): - connection.execute_wrappers.append(_raise_exception) + connection.execute_wrappers.append(custom_wrapper or _raise_exception) def _enable_queries(): diff --git a/zen_queries/tests/tests.py b/zen_queries/tests/tests.py index ee6ab0b..b38f75e 100644 --- a/zen_queries/tests/tests.py +++ b/zen_queries/tests/tests.py @@ -1,3 +1,5 @@ +from django.conf import settings +from django.db import connections from django.shortcuts import render as django_render from django.test import TestCase from rest_framework import serializers @@ -17,6 +19,8 @@ ) from zen_queries.tests.models import Widget +import importlib + class ContextManagerTestCase(TestCase): def test_queries_disabled(self): @@ -190,3 +194,33 @@ def test_serializer_with_list(self): self.assertTrue( isinstance(view.get_serializer(), QueriesDisabledSerializerMixin) ) + + +def custom_queries_disabled_handler(): + raise QueriesDisabledError("Custom queries disabled handler") + + +class CustomQueryWrapperTestCase(TestCase): + def setUp(self): + super().setUp() + self.original_wrapper = getattr(settings, "ZEN_QUERIES_DISABLED_HANDLER", None) + settings.ZEN_QUERIES_DISABLED_HANDLER = ( + "zen_queries.tests.tests.custom_queries_disabled_handler" + ) + + def tearDown(self): + settings.ZEN_QUERIES_DISABLED_HANDLER = self.original_wrapper + super().tearDown() + + def test_custom_wrapper(self): + custom_wrapper_path = getattr(settings, "ZEN_QUERIES_DISABLED_HANDLER", None) + self.assertIsNotNone(custom_wrapper_path) + + module_path, function_name = custom_wrapper_path.rsplit(".", 1) + + module = importlib.import_module(module_path) + custom_wrapper = getattr(module, function_name) + + with queries_disabled(): + for connection in connections.all(): + self.assertIn(custom_wrapper, connection.execute_wrappers)