Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: flask asyncio support for dataloaders #66

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 9 additions & 24 deletions graphql_server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from graphql.error import format_error as format_error_default
from graphql.execution import ExecutionResult, execute
from graphql.language import OperationType, parse
from graphql.pyutils import AwaitableOrValue
from graphql.pyutils import AwaitableOrValue, is_awaitable
from graphql.type import GraphQLSchema, validate_schema
from graphql.utilities import get_operation_ast
from graphql.validation import ASTValidationRule, validate
Expand Down Expand Up @@ -99,9 +99,7 @@ def run_http_query(

if not is_batch:
if not isinstance(data, (dict, MutableMapping)):
raise HttpQueryError(
400, f"GraphQL params should be a dict. Received {data!r}."
)
raise HttpQueryError(400, f"GraphQL params should be a dict. Received {data!r}.")
data = [data]
elif not batch_enabled:
raise HttpQueryError(400, "Batch GraphQL requests are not enabled.")
Expand All @@ -114,15 +112,10 @@ def run_http_query(
if not is_batch:
extra_data = query_data or {}

all_params: List[GraphQLParams] = [
get_graphql_params(entry, extra_data) for entry in data
]
all_params: List[GraphQLParams] = [get_graphql_params(entry, extra_data) for entry in data]

results: List[Optional[AwaitableOrValue[ExecutionResult]]] = [
get_response(
schema, params, catch_exc, allow_only_query, run_sync, **execute_options
)
for params in all_params
get_response(schema, params, catch_exc, allow_only_query, run_sync, **execute_options) for params in all_params
]
return GraphQLResponse(results, all_params)

Expand Down Expand Up @@ -160,10 +153,7 @@ def encode_execution_results(
Returns a ServerResponse tuple with the serialized response as the first item and
a status code of 200 or 400 in case any result was invalid as the second item.
"""
results = [
format_execution_result(execution_result, format_error)
for execution_result in execution_results
]
results = [format_execution_result(execution_result, format_error) for execution_result in execution_results]
result, status_codes = zip(*results)
status_code = max(status_codes)

Expand Down Expand Up @@ -274,14 +264,11 @@ def get_response(
if operation != OperationType.QUERY.value:
raise HttpQueryError(
405,
f"Can only perform a {operation} operation"
" from a POST request.",
f"Can only perform a {operation} operation" " from a POST request.",
headers={"Allow": "POST"},
)

validation_errors = validate(
schema, document, rules=validation_rules, max_errors=max_errors
)
validation_errors = validate(schema, document, rules=validation_rules, max_errors=max_errors)
if validation_errors:
return ExecutionResult(data=None, errors=validation_errors)

Expand All @@ -290,7 +277,7 @@ def get_response(
document,
variable_values=params.variables,
operation_name=params.operation_name,
is_awaitable=assume_not_awaitable if run_sync else None,
is_awaitable=assume_not_awaitable if run_sync else is_awaitable,
**kwargs,
)

Expand All @@ -317,9 +304,7 @@ def format_execution_result(
fe = [format_error(e) for e in execution_result.errors] # type: ignore
response = {"errors": fe}

if execution_result.errors and any(
not getattr(e, "path", None) for e in execution_result.errors
):
if execution_result.errors and any(not getattr(e, "path", None) for e in execution_result.errors):
status_code = 400
else:
response["data"] = execution_result.data
Expand Down
68 changes: 40 additions & 28 deletions graphql_server/flask/graphqlview.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import asyncio
import copy
from collections.abc import MutableMapping
from functools import partial
from typing import List

from flask import Response, render_template_string, request
from flask.views import View
from graphql import ExecutionResult
from graphql.error import GraphQLError
from graphql.pyutils import is_awaitable
from graphql.type.schema import GraphQLSchema

from graphql_server import (
Expand Down Expand Up @@ -41,6 +44,7 @@ class GraphQLView(View):
default_query = None
header_editor_enabled = None
should_persist_headers = None
enable_async = False

methods = ["GET", "POST", "PUT", "DELETE"]

Expand All @@ -53,26 +57,46 @@ def __init__(self, **kwargs):
if hasattr(self, key):
setattr(self, key, value)

assert isinstance(
self.schema, GraphQLSchema
), "A Schema is required to be provided to GraphQLView."
assert isinstance(self.schema, GraphQLSchema), "A Schema is required to be provided to GraphQLView."

def get_root_value(self):
return self.root_value

def get_context(self):
context = (
copy.copy(self.context)
if self.context and isinstance(self.context, MutableMapping)
else {}
)
context = copy.copy(self.context) if self.context and isinstance(self.context, MutableMapping) else {}
if isinstance(context, MutableMapping) and "request" not in context:
context.update({"request": request})
return context

def get_middleware(self):
return self.middleware

def get_async_execution_results(self, request_method, data, catch):
async def await_execution_results():
execution_results, all_params = self.run_http_query(request_method, data, catch)
return [
ex if ex is None or not is_awaitable(ex) else await ex
for ex in execution_results
], all_params

q = asyncio.run(await_execution_results())
return q

def run_http_query(self, request_method, data, catch):
return run_http_query(
self.schema,
request_method,
data,
query_data=request.args,
batch_enabled=self.batch,
catch=catch,
# Execute options
root_value=self.get_root_value(),
context_value=self.get_context(),
middleware=self.get_middleware(),
run_sync=not self.enable_async,
)

def dispatch_request(self):
try:
request_method = request.method.lower()
Expand All @@ -84,18 +108,12 @@ def dispatch_request(self):
pretty = self.pretty or show_graphiql or request.args.get("pretty")

all_params: List[GraphQLParams]
execution_results, all_params = run_http_query(
self.schema,
request_method,
data,
query_data=request.args,
batch_enabled=self.batch,
catch=catch,
# Execute options
root_value=self.get_root_value(),
context_value=self.get_context(),
middleware=self.get_middleware(),
)

if self.enable_async:
execution_results, all_params = self.get_async_execution_results(request_method, data, catch)
else:
execution_results, all_params = self.run_http_query(request_method, data, catch)

result, status_code = encode_execution_results(
execution_results,
is_batch=isinstance(data, list),
Expand Down Expand Up @@ -123,9 +141,7 @@ def dispatch_request(self):
header_editor_enabled=self.header_editor_enabled,
should_persist_headers=self.should_persist_headers,
)
source = render_graphiql_sync(
data=graphiql_data, config=graphiql_config, options=graphiql_options
)
source = render_graphiql_sync(data=graphiql_data, config=graphiql_config, options=graphiql_options)
return render_template_string(source)

return Response(result, status=status_code, content_type="application/json")
Expand Down Expand Up @@ -167,8 +183,4 @@ def should_display_graphiql(self):
@staticmethod
def request_wants_html():
best = request.accept_mimetypes.best_match(["application/json", "text/html"])
return (
best == "text/html"
and request.accept_mimetypes[best]
> request.accept_mimetypes["application/json"]
)
return best == "text/html" and request.accept_mimetypes[best] > request.accept_mimetypes["application/json"]
9 changes: 5 additions & 4 deletions tests/flask/app.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from flask import Flask

from graphql_server.flask import GraphQLView
from tests.flask.schema import Schema
from tests.flask.schema import AsyncSchema, Schema


def create_app(path="/graphql", **kwargs):
server = Flask(__name__)
server.debug = True
server.add_url_rule(
path, view_func=GraphQLView.as_view("graphql", schema=Schema, **kwargs)
)
if kwargs.get("enable_async", None):
server.add_url_rule(path, view_func=GraphQLView.as_view("graphql", schema=AsyncSchema, **kwargs))
else:
server.add_url_rule(path, view_func=GraphQLView.as_view("graphql", schema=Schema, **kwargs))
return server


Expand Down
25 changes: 22 additions & 3 deletions tests/flask/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,28 @@ def resolve_raises(*_):

MutationRootType = GraphQLObjectType(
name="MutationRoot",
fields={
"writeTest": GraphQLField(type_=QueryRootType, resolve=lambda *_: QueryRootType)
},
fields={"writeTest": GraphQLField(type_=QueryRootType, resolve=lambda *_: QueryRootType)},
)

Schema = GraphQLSchema(QueryRootType, MutationRootType)


async def async_resolver(obj, info):
return "async"


AsyncQueryRootType = GraphQLObjectType(
name="QueryRoot",
fields={
"sync": GraphQLField(GraphQLNonNull(GraphQLString), resolve=lambda obj, info: "sync"),
"nsync": GraphQLField(GraphQLNonNull(GraphQLString), resolve=async_resolver),
},
)
AsyncMutationRootType = GraphQLObjectType(
name="MutationRoot",
fields={
"sync": GraphQLField(type_=GraphQLString, resolve=lambda obj, info: "sync"),
"nsync": GraphQLField(type_=GraphQLString, resolve=async_resolver),
},
)
AsyncSchema = GraphQLSchema(AsyncQueryRootType, AsyncMutationRootType)
Loading