-
Notifications
You must be signed in to change notification settings - Fork 1
/
schema_auto.py
78 lines (59 loc) · 2.1 KB
/
schema_auto.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""Automatically create a base graphene query object from SQLAlchemy models.
"""
import cytoolz as cz
import graphene
from graphene import relay
from graphene_sqlalchemy import (
SQLAlchemyObjectType, SQLAlchemyConnectionField)
def _root_name(model):
"""Return the plural (append 's') lower case name of the model for the
query root name.
"""
return model.__name__.lower() + "s"
def _resolver_name(model):
    """Build the resolver attribute name matching a model's query root.

    The result is ``"resolve_"`` followed by the root name computed by
    ``_root_name``.
    """
    root = _root_name(model)
    return "".join(("resolve_", root))
def _type_class(model, with_relay):
    """Create a ``SQLAlchemyObjectType`` subclass for a SQLAlchemy model.

    The generated class is named after the model and carries an inner
    ``Meta`` class with ``model`` set.  When ``with_relay`` is truthy,
    ``Meta.interfaces`` is additionally set to ``(relay.Node,)``.
    """
    meta_attrs = {"model": model}
    if with_relay:
        meta_attrs["interfaces"] = (relay.Node,)
    meta_cls = type("Meta", (), meta_attrs)
    return type(model.__name__, (SQLAlchemyObjectType,), {"Meta": meta_cls})
def _root(type_class, with_relay):
    """Return the root field wrapping ``type_class``.

    A ``SQLAlchemyConnectionField`` is used when ``with_relay`` is truthy,
    a plain ``graphene.List`` otherwise.
    """
    field_factory = SQLAlchemyConnectionField if with_relay else graphene.List
    return field_factory(type_class)
def _resolver_all(model, db_session):
def _res(self, args, context, info):
return db_session.query(model).all()
return _res
def _build_roots(models, with_relay):
    """Map query root names to root fields for the given models.

    Each model contributes one entry keyed by ``_root_name``.  When
    ``with_relay`` is truthy, a ``"node"`` entry holding
    ``relay.Node.Field()`` is added as well.
    """
    roots = {}
    for model in models:
        name = _root_name(model)
        roots[name] = _root(_type_class(model, with_relay), with_relay)
    if with_relay:
        roots["node"] = relay.Node.Field()
    return roots
def _build_resolvers(models, db_session):
return {
_resolver_name(m): _resolver_all(m, db_session)
for m in models}
def build_query(models, db_session, with_relay=False):
    """Assemble a graphene ``Query`` object type for the given models.

    Root fields are produced by ``_build_roots`` for every model.  Without
    relay, a fetch-all resolver bound to ``db_session`` is attached for each
    root.  With relay, no resolvers are attached and ``db_session`` is not
    used by this function.
    """
    attrs = dict(_build_roots(models, with_relay))
    if not with_relay:
        attrs.update(_build_resolvers(models, db_session))
    return type('Query', (graphene.ObjectType,), attrs)