Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance query performance using batch-loading #881

Merged
merged 6 commits into from
Aug 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions back/boxtribute_server/business_logic/beneficiary/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,6 @@ def resolve_beneficiary_active(beneficiary_obj, _):


@beneficiary.field("base")
def resolve_beneficiary_base(beneficiary_obj, _):
def resolve_beneficiary_base(beneficiary_obj, info):
authorize(permission="base:read", base_id=beneficiary_obj.base_id)
return beneficiary_obj.base
return info.context["base_loader"].load(beneficiary_obj.base_id)
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,20 @@
transfer_agreement = ObjectType("TransferAgreement")


@transfer_agreement.field("sourceOrganisation")
def resolve_agreement_source_organisation(transfer_agreement_obj, info):
return info.context["organisation_loader"].load(
transfer_agreement_obj.source_organisation_id
)


@transfer_agreement.field("targetOrganisation")
def resolve_agreement_target_organisation(transfer_agreement_obj, info):
return info.context["organisation_loader"].load(
transfer_agreement_obj.target_organisation_id
)


@transfer_agreement.field("sourceBases")
def resolve_transfer_agreement_source_bases(transfer_agreement_obj, _):
source_bases = retrieve_transfer_agreement_bases(
Expand Down Expand Up @@ -42,3 +56,18 @@ def resolve_transfer_agreement_shipments(transfer_agreement_obj, _):
authorized_bases_filter(Shipment, base_fk_field_name="source_base")
| authorized_bases_filter(Shipment, base_fk_field_name="target_base"),
)


@transfer_agreement.field("requestedBy")
def resolve_shipment_requested_by(transfer_agreement_obj, info):
return info.context["user_loader"].load(transfer_agreement_obj.requested_by_id)


@transfer_agreement.field("acceptedBy")
def resolve_shipment_accepted_by(transfer_agreement_obj, info):
return info.context["user_loader"].load(transfer_agreement_obj.accepted_by_id)


@transfer_agreement.field("terminatedBy")
def resolve_shipment_terminated_by(transfer_agreement_obj, info):
return info.context["user_loader"].load(transfer_agreement_obj.terminated_by_id)
5 changes: 5 additions & 0 deletions back/boxtribute_server/business_logic/core/base/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@
base = ObjectType("Base")


@base.field("organisation")
def resolve_base_organisation(base_obj, info):
return info.context["organisation_loader"].load(base_obj.organisation_id)


@base.field("products")
def resolve_base_products(base_obj, *_):
authorize(permission="product:read", base_id=base_obj.id)
Expand Down
6 changes: 6 additions & 0 deletions back/boxtribute_server/business_logic/tag/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,9 @@ def resolve_tag_tagged_resources(tag_obj, _):
authorized_bases_filter(Beneficiary),
)
) + list(Box.select().where(Box.id << [r.object_id for r in box_relations]))


@tag.field("base")
def resolve_tag_base(tag_obj, info):
authorize(permission="base:read", base_id=tag_obj.base_id)
return info.context["base_loader"].load(tag_obj.base_id)
5 changes: 2 additions & 3 deletions back/boxtribute_server/business_logic/user/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from ...authz import authorize, authorized_bases_filter
from ...models.definitions.base import Base
from ...models.definitions.organisation import Organisation

user = ObjectType("User")

Expand All @@ -20,7 +19,7 @@ def resolve_user_email(user_obj, _):


@user.field("organisation")
def resolve_user_organisation(user_obj, _):
def resolve_user_organisation(user_obj, info):
if user_obj.id != g.user.id:
# If the queried user is different from the current user, we don't have a way
# yet to fetch information about that user's organisation
Expand All @@ -30,4 +29,4 @@ def resolve_user_organisation(user_obj, _):
# God user does not belong to an organisation
return

return Organisation.get_by_id(g.user.organisation_id)
return info.context["organisation_loader"].load(g.user.organisation_id)
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@ def resolve_product_gender(product_obj, _):


@product.field("base")
def resolve_product_base(product_obj, _):
def resolve_product_base(product_obj, info):
authorize(permission="base:read", base_id=product_obj.base_id)
return product_obj.base
return info.context["base_loader"].load(product_obj.base_id)
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@
@qr_code.field("box")
def resolve_qr_code_box(qr_code_obj, _):
try:
box = Box.select().join(Location).where(Box.qr_code == qr_code_obj.id).get()
box = (
Box.select(Box, Location.base)
.join(Location)
.where(Box.qr_code == qr_code_obj.id)
.get()
)
authorize(permission="stock:read", base_id=box.location.base_id)
except Box.DoesNotExist:
box = None
Expand Down
2 changes: 2 additions & 0 deletions back/boxtribute_server/graph_ql/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
BaseLoader,
BoxLoader,
LocationLoader,
OrganisationLoader,
ProductCategoryLoader,
ProductLoader,
ShipmentDetailForBoxLoader,
Expand All @@ -33,6 +34,7 @@ async def run():
"base_loader": BaseLoader(),
"box_loader": BoxLoader(),
"location_loader": LocationLoader(),
"organisation_loader": OrganisationLoader(),
"product_category_loader": ProductCategoryLoader(),
"product_loader": ProductLoader(),
"shipment_detail_for_box_loader": ShipmentDetailForBoxLoader(),
Expand Down
105 changes: 61 additions & 44 deletions back/boxtribute_server/graph_ql/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ..models.definitions.base import Base
from ..models.definitions.box import Box
from ..models.definitions.location import Location
from ..models.definitions.organisation import Organisation
from ..models.definitions.product import Product
from ..models.definitions.product_category import ProductCategory
from ..models.definitions.shipment import Shipment
Expand All @@ -16,6 +17,7 @@
from ..models.definitions.tag import Tag
from ..models.definitions.tags_relation import TagsRelation
from ..models.definitions.user import User
from ..utils import convert_pascal_to_snake_case


class DataLoader(_DataLoader):
Expand All @@ -27,37 +29,73 @@ def load(self, key):
return super().load(key)


class BaseLoader(DataLoader):
async def batch_load_fn(self, keys):
bases = {b.id: b for b in Base.select().where(Base.id << keys)}
return [bases.get(i) for i in keys]
class SimpleDataLoader(DataLoader):
"""Custom implementation that batch-loads all requested rows of the specified data
model, optionally enforcing authorization for the resource.
Authorization may be skipped for base-specific resources.
"""

def __init__(self, model, skip_authorize=False):
super().__init__()
self.model = model
self.skip_authorize = skip_authorize

class ProductLoader(DataLoader):
async def batch_load_fn(self, keys):
products = {p.id: p for p in Product.select().where(Product.id << keys)}
return [products.get(i) for i in keys]
async def batch_load_fn(self, ids):
if not self.skip_authorize:
resource = convert_pascal_to_snake_case(self.model.__name__)
# work-around for inconsistent RBP naming
if resource == "product_category":
resource = "category"
permission = f"{resource}:read"
authorize(permission=permission)

rows = {r.id: r for r in self.model.select().where(self.model.id << ids)}
return [rows.get(i) for i in ids]

class LocationLoader(DataLoader):
async def batch_load_fn(self, keys):
locations = {
loc.id: loc for loc in Location.select().where(Location.id << keys)
}
return [locations.get(i) for i in keys]

class BaseLoader(SimpleDataLoader):
def __init__(self):
super().__init__(Base, skip_authorize=True)

class SizeLoader(DataLoader):
async def batch_load_fn(self, keys):
authorize(permission="size:read")
sizes = {s.id: s for s in Size.select()}
return [sizes.get(i) for i in keys]

class ProductLoader(SimpleDataLoader):
def __init__(self):
super().__init__(Product, skip_authorize=True)


class LocationLoader(SimpleDataLoader):
def __init__(self):
super().__init__(Location, skip_authorize=True)


class BoxLoader(SimpleDataLoader):
def __init__(self):
super().__init__(Box, skip_authorize=True)


class SizeLoader(SimpleDataLoader):
def __init__(self):
super().__init__(Size)


class OrganisationLoader(SimpleDataLoader):
def __init__(self):
super().__init__(Organisation)

class BoxLoader(DataLoader):
async def batch_load_fn(self, keys):
boxes = {b.id: b for b in Box.select().where(Box.id << keys)}
return [boxes.get(i) for i in keys]

class UserLoader(SimpleDataLoader):
def __init__(self):
super().__init__(User)


class ProductCategoryLoader(SimpleDataLoader):
def __init__(self):
super().__init__(ProductCategory)


class SizeRangeLoader(SimpleDataLoader):
def __init__(self):
super().__init__(SizeRange)


class ShipmentLoader(DataLoader):
Expand Down Expand Up @@ -106,20 +144,6 @@ async def batch_load_fn(self, keys):
return [details.get(i) for i in keys]


class ProductCategoryLoader(DataLoader):
async def batch_load_fn(self, keys):
authorize(permission="category:read")
categories = {c.id: c for c in ProductCategory.select()}
return [categories.get(i) for i in keys]


class SizeRangeLoader(DataLoader):
async def batch_load_fn(self, keys):
authorize(permission="size_range:read")
ranges = {s.id: s for s in SizeRange.select()}
return [ranges.get(i) for i in keys]


class SizesForSizeRangeLoader(DataLoader):
async def batch_load_fn(self, keys):
authorize(permission="size:read")
Expand All @@ -129,10 +153,3 @@ async def batch_load_fn(self, keys):
sizes[size.size_range_id].append(size)
# Keys are in fact size range IDs. Return empty list if size range has no sizes
return [sizes.get(i, []) for i in keys]


class UserLoader(DataLoader):
async def batch_load_fn(self, keys):
authorize(permission="user:read")
users = {s.id: s for s in User.select().where(User.id << keys)}
return [users.get(i) for i in keys]
Loading