diff --git a/ariadne_django/views/base.py b/ariadne_django/views/base.py index 9dcc153..ded22d4 100644 --- a/ariadne_django/views/base.py +++ b/ariadne_django/views/base.py @@ -1,5 +1,5 @@ import json -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, Type, Union from django.conf import settings from django.http import HttpRequest @@ -14,8 +14,7 @@ from ariadne.types import ContextValue, ErrorFormatter, ExtensionList, RootValue, ValidationRules from graphql import GraphQLSchema -from graphql.execution import MiddlewareManager - +from graphql.execution import MiddlewareManager, ExecutionContext Extensions = Union[Callable[[Any, Optional[ContextValue]], ExtensionList], ExtensionList] @@ -44,6 +43,7 @@ class BaseGraphQLView(TemplateResponseMixin, ContextMixin, View): error_formatter: Optional[ErrorFormatter] = None extensions: Optional[Extensions] = None middleware: Optional[MiddlewareManager] = None + execution_context_class: Optional[Type[ExecutionContext]] = None def _get(self, request: HttpRequest, *args, **kwargs): # pylint: disable=unused-argument options = DEFAULT_PLAYGROUND_OPTIONS.copy() @@ -99,9 +99,10 @@ def get_kwargs_graphql(self, request: HttpRequest) -> dict: "error_formatter": self.error_formatter or format_error, "extensions": extensions, "middleware": self.middleware, + "execution_context_class": self.execution_context_class or ExecutionContext, } - def get_context_for_request(self, request: HttpRequest) -> Optional[ContextValue]: + def get_context_for_request(self, request: HttpRequest, data=None) -> Optional[ContextValue]: if callable(self.context_value): return self.context_value(request) # pylint: disable=not-callable return self.context_value or {"request": request} diff --git a/requirements-dev.txt b/requirements-dev.txt index 77a7b09..bda1f9b 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,6 +1,6 @@ -r requirements.txt black==22.6.0 -codecov==2.1.11 +codecov>=2.1.11 django-stubs==1.7.0 isort==5.7.0 mypy==0.812 diff --git a/requirements.txt b/requirements.txt index f601f9e..b408462 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ -ariadne>=0.13.0 +ariadne>=0.18.0 django>=2.2 python-dateutil==2.8.1 diff --git a/setup.py b/setup.py index c1e6809..1eba6c9 100644 --- a/setup.py +++ b/setup.py @@ -13,6 +13,7 @@ "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", "Topic :: Software Development :: Libraries :: Python Modules", ] @@ -34,7 +35,7 @@ include_package_data=True, install_requires=[ "django>=2.2.0", - "ariadne>=0.13.0", + "ariadne>=0.18.0", ], classifiers=CLASSIFIERS, platforms=["any"], diff --git a/tests/conftest.py b/tests/conftest.py index 19eab7f..0dda2e1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,7 +6,7 @@ from ariadne import MutationType, QueryType, SubscriptionType, make_executable_schema, upload_scalar import pytest -from graphql import ValidationRule +from graphql import ExecutionContext, ValidationRule def pytest_configure(): @@ -190,3 +190,11 @@ class NoopRule(ValidationRule): pass return NoopRule + + +@pytest.fixture +def execution_context_class(): + class CustomExecutionContext(ExecutionContext): + pass + + return CustomExecutionContext diff --git a/tests/test_configuration.py b/tests/test_configuration.py index 374d5f2..6464db0 100644 --- a/tests/test_configuration.py +++ b/tests/test_configuration.py @@ -1,10 +1,18 @@ import json -from unittest.mock import ANY, Mock +from typing import List +from unittest.mock import ANY, Mock, call from django.test import override_settings - -from ariadne.types import ExtensionSync - +from graphql.language import FieldNode +from graphql.pyutils import Path + +try: + from ariadne.types import ExtensionSync +except ImportError: + # From ariadne 0.20 Extension supports both sync and async contexts + # https://github.com/mirumee/ariadne/blob/main/CHANGELOG.md#020-2023-06-21 + from ariadne.types import Extension as ExtensionSync +from graphql import ExecutionContext, GraphQLBoolean, GraphQLResolveInfo, GraphQLScalarType import pytest from ariadne_django.views import GraphQLView @@ -56,6 +64,18 @@ def test_custom_context_value_function_result_is_passed_to_resolvers(request_fac assert data == {"data": {"testContext": "TEST-CONTEXT"}} +def test_custom_execution_context_is_used_to_execute_operation(mocker, request_factory, schema, execution_context_class): + spy_execution_context_execute_operation = mocker.spy(execution_context_class, 'execute_operation') + + execute_query( + request_factory, + schema, + {"query": "{ status }"}, + execution_context_class=execution_context_class, + ) + spy_execution_context_execute_operation.assert_called_once() + + def test_custom_root_value_is_passed_to_resolvers(request_factory, schema): data = execute_query( request_factory, @@ -81,7 +101,7 @@ def test_custom_root_value_function_is_called_with_context_value(request_factory context_value={"test": "TEST-CONTEXT"}, root_value=get_root_value, ) - get_root_value.assert_called_once_with({"test": "TEST-CONTEXT"}, ANY) + get_root_value.assert_called_once_with({"test": "TEST-CONTEXT"}, ANY, ANY, ANY) def test_custom_validation_rule_is_called_by_query_validation(mocker, request_factory, schema, validation_rule):