Skip to content

Commit

Permalink
ConstructQuery refactoring, now it subclasses sqlalchemy.orm.query:Qu…
Browse files Browse the repository at this point in the history
…ery class with ability to create your own ConstructQuery with your own base query class.
  • Loading branch information
vmagamedov committed Feb 12, 2014
1 parent cc25e00 commit 8048db0
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 56 deletions.
80 changes: 29 additions & 51 deletions sqlconstruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,7 @@ def __init__(self, name, *exprs, **kw):

class _QueryPlan(object):

def __init__(self, session=None):
self._session = session
def __init__(self):
self._queries = OrderedSet({None})
self._columns = defaultdict(OrderedSet)
self._children = defaultdict(OrderedSet)
Expand All @@ -115,10 +114,10 @@ def query_columns(self, query):
def query_children(self, query):
return tuple(self._children.get(query) or ())

def process_rows(self, rows):
def process_rows(self, rows, session):
children = self.query_children(None)
queries = (None,) + children
results = (rows,) + tuple(child.process(self, None, rows)
results = (rows,) + tuple(child.process(self, None, rows, session)
for child in children)
return {0: [
tuple(zip(map(self.query_id, queries),
Expand All @@ -142,7 +141,7 @@ def __init__(self, query_plan, query=None, parent=None):
def lookup(self, column):
scope = self
while scope.query:
if column in scope.query:
if scope.query.__contains_column__(column):
return scope
scope = scope.parent
return scope
Expand Down Expand Up @@ -185,7 +184,7 @@ def __hash__(self):
def __eq__(self, other):
return hash(self) == hash(other)

def __contains__(self, column):
def __contains_column__(self, column):
raise NotImplementedError

def __reference__(self):
Expand All @@ -194,7 +193,7 @@ def __reference__(self):
def __requires__(self):
return tuple()

def process(self, query_plan, outer_query, outer_rows):
def process(self, query_plan, outer_query, outer_rows, session):
raise NotImplementedError


Expand Down Expand Up @@ -230,18 +229,18 @@ def __requires__(self):
def __hash__(self):
return self._hash

def __contains__(self, column):
def __contains_column__(self, column):
return any(el.c.contains_column(column)
for el in self._sa_query.statement.froms)

def process(self, query_plan, outer_query, outer_rows):
def process(self, query_plan, outer_query, outer_rows, session):
ext_col_id = query_plan.column_id(outer_query, self._ext_expr)
ext_exprs = [row[ext_col_id] for row in outer_rows]
if ext_exprs:
columns = query_plan.query_columns(self) + (self._int_expr,)
rows = (
self._sa_query
.with_session(query_plan._session.registry())
.with_session(session)
.with_entities(*columns)
.filter(self._int_expr.in_(set(ext_exprs)))
.all()
Expand Down Expand Up @@ -301,18 +300,18 @@ def __requires__(self):
def __hash__(self):
return self._hash

def __contains__(self, column):
def __contains_column__(self, column):
return any(el.c.contains_column(column)
for el in self._sa_query.statement.froms)

def process(self, query_plan, outer_query, outer_rows):
def process(self, query_plan, outer_query, outer_rows, session):
ext_col_id = query_plan.column_id(outer_query, self._ext_expr)
ext_exprs = [row[ext_col_id] for row in outer_rows]
if ext_exprs:
columns = query_plan.query_columns(self) + (self._int_expr,)
rows = (
self._sa_query
.with_session(query_plan._session.registry())
.with_session(session)
.with_entities(*columns)
.filter(self._int_expr.in_(ext_exprs))
.all()
Expand Down Expand Up @@ -375,51 +374,30 @@ def __reduce__(self):
return type(self), (dict(self),)


def _proxy_query_method(unbound_method):
func = _im_func(unbound_method)
@wraps(func)
def wrapper(self, *args, **kwargs):
return func(self._query, *args, **kwargs)
return wrapper


def _generative_proxy_query_method(unbound_method):
func = _im_func(unbound_method)
@wraps(func)
def wrapper(self, *args, **kwargs):
cls = type(self)
clone = cls.__new__(cls)
clone.__dict__.update(self.__dict__)
clone._query = func(clone._query, *args, **kwargs)
return clone
return wrapper
class ConstructQueryMixin(object):

def __init__(self, spec):
self._cq_keys, values = zip(*spec.items()) if spec else [(), ()]
self._cq_scope = _Scope(_QueryPlan())
self._cq_procs = [_get_value_processor(self._cq_scope, val)
for val in values]

class ConstructQuery(object):
columns = self._cq_scope.query_plan.query_columns(None)
super(ConstructQueryMixin, self).__init__(columns)

def __init__(self, session, spec):
self._session = session
self._spec = spec
self._keys, self._values = zip(*spec.items()) if spec else [(), ()]
self._scope = _Scope(_QueryPlan(session))
self._processors = [_get_value_processor(self._scope, val)
for val in self._values]
self._query = _SAQuery(self._scope.query_plan.query_columns(None))
def __iter__(self):
rows = list(super(ConstructQueryMixin, self).__iter__())
result = self._cq_scope.query_plan.process_rows(rows, self.session)
for r in self._cq_scope.gen_loop()(result):
values = [proc(r) for proc in self._cq_procs]
yield Object(zip(self._cq_keys, values))

__str__ = _proxy_query_method(_SAQuery.__str__)

join = _generative_proxy_query_method(_SAQuery.join)
outerjoin = _generative_proxy_query_method(_SAQuery.outerjoin)
filter = _generative_proxy_query_method(_SAQuery.filter)
order_by = _generative_proxy_query_method(_SAQuery.order_by)
def construct_query_maker(base_cls):
return type('ConstructQuery', (ConstructQueryMixin, base_cls), {})

all = _im_func(_SAQuery.all)

def __iter__(self):
rows = self._query.with_session(self._session.registry()).all()
result = self._scope.query_plan.process_rows(rows)
for r in self._scope.gen_loop()(result):
yield Object(zip(self._keys, [proc(r) for proc in self._processors]))
ConstructQuery = construct_query_maker(_SAQuery)


class Construct(Bundle):
Expand Down
15 changes: 10 additions & 5 deletions test_sqlconstruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,10 +695,11 @@ class B(self.base_cls):
session.commit()

query = (
ConstructQuery(session, {
ConstructQuery({
'a_name': A.name,
'b_name': get_(apply_(capitalize, [B.name]), A.b),
})
.with_session(session.registry())
)
self.assertEqual(
tuple(dict(obj) for obj in query.all()),
Expand Down Expand Up @@ -732,10 +733,11 @@ class B(self.base_cls):
session.commit()

query = (
ConstructQuery(session, {
ConstructQuery({
'a_name': A.name,
'b_name': get_(apply_(capitalize, [B.name]), A.b),
})
.with_session(session.registry())
)
self.assertEqual(
tuple(dict(obj) for obj in query.all()),
Expand Down Expand Up @@ -763,10 +765,11 @@ class B(self.base_cls):
session.commit()

query = (
ConstructQuery(session, {
ConstructQuery({
'a_name': A.name,
'b_names': map_(apply_(capitalize, [B.name]), A.b_list),
})
.with_session(session.registry())
)
self.assertEqual(
tuple(dict(obj) for obj in query.all()),
Expand Down Expand Up @@ -802,10 +805,11 @@ class B(self.base_cls):
session.commit()

q1 = (
ConstructQuery(session, {
ConstructQuery({
'a_name': A.name,
'b_names': map_(apply_(capitalize, [B.name]), A.b_list),
})
.with_session(session.registry())
.order_by(A.name)
)
self.assertEqual(
Expand All @@ -819,10 +823,11 @@ class B(self.base_cls):
)

q2 = (
ConstructQuery(session, {
ConstructQuery({
'b_name': B.name,
'a_names': map_(apply_(capitalize, [A.name]), B.a_list),
})
.with_session(session.registry())
.order_by(B.name)
)
self.assertEqual(
Expand Down

0 comments on commit 8048db0

Please sign in to comment.