Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Quart Server Integration #70

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions graphql_server/aiohttp/graphqlview.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ def get_context(self, request):
def get_middleware(self):
return self.middleware

# This method can be static
async def parse_body(self, request):
@staticmethod
async def parse_body(request):
content_type = request.content_type
# request.text() is the aiohttp equivalent to
# request.body.decode("utf8")
Expand Down
7 changes: 4 additions & 3 deletions graphql_server/flask/graphqlview.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ def dispatch_request(self):
content_type="application/json",
)

# Flask
def parse_body(self):
@staticmethod
def parse_body():
# We use mimetype here since we don't need the other
# information provided by content_type
content_type = request.mimetype
Expand All @@ -164,7 +164,8 @@ def should_display_graphiql(self):

return self.request_wants_html()

def request_wants_html(self):
@staticmethod
def request_wants_html():
best = request.accept_mimetypes.best_match(["application/json", "text/html"])
return (
best == "text/html"
Expand Down
3 changes: 3 additions & 0 deletions graphql_server/quart/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .graphqlview import GraphQLView

__all__ = ["GraphQLView"]
201 changes: 201 additions & 0 deletions graphql_server/quart/graphqlview.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
import copy
import sys
from collections.abc import MutableMapping
from functools import partial
from typing import List

from graphql import ExecutionResult
from graphql.error import GraphQLError
from graphql.type.schema import GraphQLSchema
from quart import Response, render_template_string, request
from quart.views import View

from graphql_server import (
GraphQLParams,
HttpQueryError,
encode_execution_results,
format_error_default,
json_encode,
load_json_body,
run_http_query,
)
from graphql_server.render_graphiql import (
GraphiQLConfig,
GraphiQLData,
GraphiQLOptions,
render_graphiql_sync,
)


class GraphQLView(View):
schema = None
root_value = None
context = None
pretty = False
graphiql = False
graphiql_version = None
graphiql_template = None
graphiql_html_title = None
middleware = None
batch = False
enable_async = False
subscriptions = None
headers = None
default_query = None
header_editor_enabled = None
should_persist_headers = None

methods = ["GET", "POST", "PUT", "DELETE"]

format_error = staticmethod(format_error_default)
encode = staticmethod(json_encode)

def __init__(self, **kwargs):
super(GraphQLView, self).__init__()
for key, value in kwargs.items():
if hasattr(self, key):
setattr(self, key, value)

assert isinstance(
self.schema, GraphQLSchema
), "A Schema is required to be provided to GraphQLView."

def get_root_value(self):
return self.root_value

def get_context(self):
context = (
copy.copy(self.context)
if self.context and isinstance(self.context, MutableMapping)
else {}
)
if isinstance(context, MutableMapping) and "request" not in context:
context.update({"request": request})
return context

def get_middleware(self):
return self.middleware

async def dispatch_request(self):
try:
request_method = request.method.lower()
data = await self.parse_body()

show_graphiql = request_method == "get" and self.should_display_graphiql()
catch = show_graphiql

pretty = self.pretty or show_graphiql or request.args.get("pretty")
all_params: List[GraphQLParams]
execution_results, all_params = run_http_query(
self.schema,
request_method,
data,
query_data=request.args,
batch_enabled=self.batch,
catch=catch,
# Execute options
run_sync=not self.enable_async,
root_value=self.get_root_value(),
context_value=self.get_context(),
middleware=self.get_middleware(),
)
exec_res = (
[
ex if ex is None or isinstance(ex, ExecutionResult) else await ex
for ex in execution_results
]
if self.enable_async
else execution_results
)
result, status_code = encode_execution_results(
exec_res,
is_batch=isinstance(data, list),
format_error=self.format_error,
encode=partial(self.encode, pretty=pretty), # noqa
)

if show_graphiql:
graphiql_data = GraphiQLData(
result=result,
query=getattr(all_params[0], "query"),
variables=getattr(all_params[0], "variables"),
operation_name=getattr(all_params[0], "operation_name"),
subscription_url=self.subscriptions,
headers=self.headers,
)
graphiql_config = GraphiQLConfig(
graphiql_version=self.graphiql_version,
graphiql_template=self.graphiql_template,
graphiql_html_title=self.graphiql_html_title,
jinja_env=None,
)
graphiql_options = GraphiQLOptions(
default_query=self.default_query,
header_editor_enabled=self.header_editor_enabled,
should_persist_headers=self.should_persist_headers,
)
source = render_graphiql_sync(
data=graphiql_data, config=graphiql_config, options=graphiql_options
)
return await render_template_string(source)

return Response(result, status=status_code, content_type="application/json")

except HttpQueryError as e:
parsed_error = GraphQLError(e.message)
return Response(
self.encode(dict(errors=[self.format_error(parsed_error)])),
status=e.status_code,
headers=e.headers,
content_type="application/json",
)

@staticmethod
async def parse_body():
# We use mimetype here since we don't need the other
# information provided by content_type
content_type = request.mimetype
if content_type == "application/graphql":
refined_data = await request.get_data(raw=False)
return {"query": refined_data}

elif content_type == "application/json":
refined_data = await request.get_data(raw=False)
return load_json_body(refined_data)

elif content_type == "application/x-www-form-urlencoded":
return await request.form

# TODO: Fix this check
elif content_type == "multipart/form-data":
return await request.files

return {}

def should_display_graphiql(self):
if not self.graphiql or "raw" in request.args:
return False

return self.request_wants_html()

@staticmethod
def request_wants_html():
best = request.accept_mimetypes.best_match(["application/json", "text/html"])

# Needed as this was introduced at Quart 0.8.0: https://gitlab.com/pgjones/quart/-/issues/189
def _quality(accept, key: str) -> float:
for option in accept.options:
if accept._values_match(key, option.value):
return option.quality
return 0.0

if sys.version_info >= (3, 7):
return (
best == "text/html"
and request.accept_mimetypes[best]
> request.accept_mimetypes["application/json"]
)
else:
return best == "text/html" and _quality(
request.accept_mimetypes, best
) > _quality(request.accept_mimetypes, "application/json")
8 changes: 7 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,17 @@
"aiohttp>=3.5.0,<4",
]

install_quart_requires = [
"quart>=0.6.15"
]

install_all_requires = \
install_requires + \
install_flask_requires + \
install_sanic_requires + \
install_webob_requires + \
install_aiohttp_requires
install_aiohttp_requires + \
install_quart_requires

with open("graphql_server/version.py") as version_file:
version = search('version = "(.*)"', version_file.read()).group(1)
Expand Down Expand Up @@ -84,6 +89,7 @@
"sanic": install_sanic_requires,
"webob": install_webob_requires,
"aiohttp": install_aiohttp_requires,
"quart": install_quart_requires,
},
include_package_data=True,
zip_safe=False,
Expand Down
8 changes: 4 additions & 4 deletions tests/flask/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@


def create_app(path="/graphql", **kwargs):
app = Flask(__name__)
app.debug = True
app.add_url_rule(
server = Flask(__name__)
server.debug = True
server.add_url_rule(
path, view_func=GraphQLView.as_view("graphql", schema=Schema, **kwargs)
)
return app
return server


if __name__ == "__main__":
Expand Down
30 changes: 6 additions & 24 deletions tests/flask/test_graphqlview.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


@pytest.fixture
def app(request):
def app():
# import app factory pattern
app = create_app()

Expand Down Expand Up @@ -269,7 +269,7 @@ def test_supports_post_url_encoded_query_with_string_variables(app, client):
assert response_json(response) == {"data": {"test": "Hello Dolly"}}


def test_supports_post_json_quey_with_get_variable_values(app, client):
def test_supports_post_json_query_with_get_variable_values(app, client):
response = client.post(
url_string(app, variables=json.dumps({"who": "Dolly"})),
data=json_dump_kwarg(query="query helloWho($who: String){ test(who: $who) }",),
Expand Down Expand Up @@ -533,49 +533,34 @@ def test_post_multipart_data(app, client):
def test_batch_allows_post_with_json_encoding(app, client):
response = client.post(
url_string(app),
data=json_dump_kwarg_list(
# id=1,
query="{test}"
),
data=json_dump_kwarg_list(query="{test}"),
content_type="application/json",
)

assert response.status_code == 200
assert response_json(response) == [
{
# 'id': 1,
"data": {"test": "Hello World"}
}
]
assert response_json(response) == [{"data": {"test": "Hello World"}}]


@pytest.mark.parametrize("app", [create_app(batch=True)])
def test_batch_supports_post_json_query_with_json_variables(app, client):
response = client.post(
url_string(app),
data=json_dump_kwarg_list(
# id=1,
query="query helloWho($who: String){ test(who: $who) }",
variables={"who": "Dolly"},
),
content_type="application/json",
)

assert response.status_code == 200
assert response_json(response) == [
{
# 'id': 1,
"data": {"test": "Hello Dolly"}
}
]
assert response_json(response) == [{"data": {"test": "Hello Dolly"}}]


@pytest.mark.parametrize("app", [create_app(batch=True)])
def test_batch_allows_post_with_operation_name(app, client):
response = client.post(
url_string(app),
data=json_dump_kwarg_list(
# id=1,
query="""
query helloYou { test(who: "You"), ...shared }
query helloWorld { test(who: "World"), ...shared }
Expand All @@ -591,8 +576,5 @@ def test_batch_allows_post_with_operation_name(app, client):

assert response.status_code == 200
assert response_json(response) == [
{
# 'id': 1,
"data": {"test": "Hello World", "shared": "Hello Everyone"}
}
{"data": {"test": "Hello World", "shared": "Hello Everyone"}}
]
Empty file added tests/quart/__init__.py
Empty file.
18 changes: 18 additions & 0 deletions tests/quart/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from quart import Quart

from graphql_server.quart import GraphQLView
from tests.quart.schema import Schema


def create_app(path="/graphql", **kwargs):
server = Quart(__name__)
server.debug = True
server.add_url_rule(
path, view_func=GraphQLView.as_view("graphql", schema=Schema, **kwargs)
)
return server


if __name__ == "__main__":
app = create_app(graphiql=True)
app.run()
Loading