diff --git a/back/boxtribute_server/graph_ql/resolvers.py b/back/boxtribute_server/graph_ql/resolvers.py index 47c5635f77..bf24e30ef3 100644 --- a/back/boxtribute_server/graph_ql/resolvers.py +++ b/back/boxtribute_server/graph_ql/resolvers.py @@ -219,26 +219,33 @@ def resolve_qr_code(obj, _, qr_code=None): @box.field("tags") -def resolve_box_tags(box_obj, _): - return ( - Tag.select() - .join(TagsRelation) - .where( - (TagsRelation.object_id == box_obj.id) - & (TagsRelation.object_type == TaggableObjectType.Box) - ) - ) +def resolve_box_tags(box_obj, info): + return info.context["tags_for_box_loader"].load(box_obj.id) @query.field("product") +def resolve_product(obj, _, id): + product = Product.get_by_id(id) + authorize(permission="product:read", base_id=product.base_id) + return product + + @box.field("product") @unboxed_items_collection.field("product") -def resolve_product(obj, _, id=None): - product = obj.product if id is None else Product.get_by_id(id) - authorize(permission="product:read", base_id=product.base_id) +def resolve_box_product(obj, info): + product = info.context["product_loader"].load(obj.product_id) + # Base-specific authz can be omitted here since it was enforced in the box + # parent-resolver. It's not possible that the box's product is assigned to a + # different base than the box is in + authorize(permission="product:read") return product +@box.field("size") +def resolve_size(box_obj, info): + return info.context["size_loader"].load(box_obj.size_id) + + @query.field("box") @convert_kwargs_to_snake_case def resolve_box(*_, label_identifier): diff --git a/back/boxtribute_server/loaders.py b/back/boxtribute_server/loaders.py new file mode 100644 index 0000000000..0ab819341e --- /dev/null +++ b/back/boxtribute_server/loaders.py @@ -0,0 +1,36 @@ +from collections import defaultdict + +from aiodataloader import DataLoader + +from .enums import TaggableObjectType +from .models.definitions.product import Product +from .models.definitions.size import Size +from .models.definitions.tag import Tag +from .models.definitions.tags_relation import TagsRelation + + +class ProductLoader(DataLoader): + async def batch_load_fn(self, keys): + products = {p.id: p for p in Product.select().where(Product.id << keys)} + return [products.get(i) for i in keys] + + +class SizeLoader(DataLoader): + async def batch_load_fn(self, keys): + sizes = {s.id: s for s in Size.select()} + return [sizes.get(i) for i in keys] + + +class TagsForBoxLoader(DataLoader): + async def batch_load_fn(self, keys): + tags = defaultdict(list) + # maybe need different join type + for relation in ( + TagsRelation.select() + .join(Tag) + .where(TagsRelation.object_type == TaggableObjectType.Box) + ): + tags[relation.object_id].append(relation.tag) + + # keys are in fact box IDs. Return empty list if box has no tags assigned + return [tags.get(i, []) for i in keys] diff --git a/back/boxtribute_server/routes.py b/back/boxtribute_server/routes.py index 0478d96659..a7a9cd0313 100644 --- a/back/boxtribute_server/routes.py +++ b/back/boxtribute_server/routes.py @@ -1,7 +1,8 @@ """Construction of routes for web app and API""" +import asyncio import os -from ariadne import graphql_sync +from ariadne import graphql, graphql_sync from ariadne.constants import PLAYGROUND_HTML from flask import Blueprint, current_app, jsonify, request from flask_cors import cross_origin @@ -9,6 +10,7 @@ from .auth import request_jwt, requires_auth from .exceptions import AuthenticationFailed, format_database_errors from .graph_ql.schema import full_api_schema, query_api_schema +from .loaders import ProductLoader, SizeLoader, TagsForBoxLoader # Blueprint for query-only API. Deployed on the 'api*' subdomains api_bp = Blueprint("api_bp", __name__) @@ -82,15 +84,27 @@ def graphql_playgroud(): @cross_origin(origin="localhost", headers=["Content-Type", "Authorization"]) @requires_auth def graphql_server(): - # Note: Passing the request to the context is optional. - # In Flask, the current request is always accessible as flask.request - success, result = graphql_sync( - full_api_schema, - data=request.get_json(), - context_value=request, - debug=current_app.debug, - introspection=current_app.debug, - error_formatter=format_database_errors, + # Start async event loop, required for DataLoader construction, cf. + # https://github.com/graphql-python/graphql-core/issues/71#issuecomment-620106364 + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + # Create DataLoaders and persist them for the time of processing the request + context = { + "product_loader": ProductLoader(), + "size_loader": SizeLoader(), + "tags_for_box_loader": TagsForBoxLoader(), + } + + success, result = loop.run_until_complete( + graphql( + full_api_schema, + data=request.get_json(), + context_value=context, + debug=current_app.debug, + introspection=current_app.debug, + error_formatter=format_database_errors, + ) ) status_code = 200 if success else 400 diff --git a/back/requirements.txt b/back/requirements.txt index 9f80383418..ceccb484a1 100644 --- a/back/requirements.txt +++ b/back/requirements.txt @@ -8,4 +8,5 @@ peewee-moves==2.1.0 python-dateutil==2.8.2 python-dotenv==0.20.0 python-jose==3.3.0 +aiodataloader==0.2.1 gunicorn diff --git a/back/scripts/load-test.js b/back/scripts/load-test.js index a7920b5a16..3ab32d3db7 100644 --- a/back/scripts/load-test.js +++ b/back/scripts/load-test.js @@ -34,8 +34,8 @@ const payload = JSON.stringify({ // query: "query { beneficiaries { elements { firstName } } }", // C) All boxes for base - // query: "query { base(id: 1) { locations { name boxes { totalCount elements { labelIdentifier state size { id label } product { gender name } tags { name id } items } } } } }", - query: "query { location(id: 1) { boxes { elements { product { gender name } } } } }", + query: "query { base(id: 1) { locations { name boxes { totalCount elements { labelIdentifier state size { id label } product { gender name } tags { name id } numberOfItems } } } } }", + // query: "query { location(id: 1) { boxes { elements { product { gender name } } } } }", }); export const options = { diff --git a/back/test/endpoint_tests/test_permissions.py b/back/test/endpoint_tests/test_permissions.py index eddf7af848..08d5d5fcec 100644 --- a/back/test/endpoint_tests/test_permissions.py +++ b/back/test/endpoint_tests/test_permissions.py @@ -223,7 +223,7 @@ def test_invalid_permission_for_shipment_base(read_only_client, mocker, field): assert_forbidden_request(read_only_client, query, value={field: None}) -@pytest.mark.parametrize("field", ["place", "product", "qrCode"]) +@pytest.mark.parametrize("field", ["place", "qrCode"]) def test_invalid_permission_for_box_field(read_only_client, mocker, default_box, field): # verify missing field:read permission mocker.patch("jose.jwt.decode").return_value = create_jwt_payload(