Skip to content

Commit

Permalink
Fixed sub-queries and added tests for nested sub-queries.
Browse files Browse the repository at this point in the history
  • Loading branch information
vmagamedov committed Feb 14, 2014
1 parent 8048db0 commit 41062e3
Show file tree
Hide file tree
Showing 2 changed files with 188 additions and 49 deletions.
65 changes: 26 additions & 39 deletions sqlconstruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
import inspect
from operator import attrgetter
from itertools import chain
from functools import partial, wraps
from functools import partial
from collections import defaultdict

import sqlalchemy
Expand Down Expand Up @@ -116,14 +116,12 @@ def query_children(self, query):

def process_rows(self, rows, session):
children = self.query_children(None)
queries = (None,) + children
results = (rows,) + tuple(child.process(self, None, rows, session)
for child in children)
return {0: [
tuple(zip(map(self.query_id, queries),
result_row))
for result_row in zip(*results)
]}

_results = [[((self.query_id(None), row),) for row in rows]]
_results.extend(child.process(self, None, rows, session)
for child in children)

return {0: [tuple(chain(*r)) for r in zip(*_results)]}


class _Scope(object):
Expand Down Expand Up @@ -170,11 +168,6 @@ def loop(result, _query_id=self.query_plan.query_id(self.query)):
yield dict(chain(_iteritems(result), item))
return loop

def gen_getter(self):
def getter(result, _query_id=self.query_plan.query_id(self.query)):
return dict(chain(_iteritems(result), result[_query_id]))
return getter


class _QueryBase(object):

Expand Down Expand Up @@ -250,24 +243,22 @@ def process(self, query_plan, outer_query, outer_rows, session):

children = query_plan.query_children(self)

queries = (self,) + children
results = (rows,) + tuple(child.process(query_plan, self, rows)
for child in children)
_results = [[((query_plan.query_id(self), row),) for row in rows]]
_results.extend(child.process(query_plan, self, rows, session)
for child in children)

col_id = query_plan.column_id(self, self._int_expr)
mapping = {}
for result_row in zip(*results):
mapping[result_row[0][col_id]] = \
tuple(zip(map(query_plan.query_id, queries),
result_row))
for r in zip(*_results):
r = tuple(chain(*r))
_, row = r[0]
mapping[row[-1]] = r

nulls = (
(query_plan.query_id(self),
tuple(None for _ in query_plan.query_columns(self))),
)

return [mapping.get(ext_expr) if ext_expr in mapping else nulls
for ext_expr in ext_exprs]
return [mapping.get(ext_expr, nulls) for ext_expr in ext_exprs]


class _RelativeCollectionSubQuery(_QueryBase):
Expand Down Expand Up @@ -321,18 +312,18 @@ def process(self, query_plan, outer_query, outer_rows, session):

children = query_plan.query_children(self)

queries = (self,) + children
results = (rows,) + tuple(child.process(query_plan, self, rows)
for child in children)
_results = [[((query_plan.query_id(self), row),) for row in rows]]
_results.extend(child.process(query_plan, self, rows, session)
for child in children)

col_id = query_plan.column_id(self, self._int_expr)
groups = defaultdict(list)
for result_row in zip(*results):
groups[result_row[0][col_id]].append(
tuple(zip(map(query_plan.query_id, queries),
result_row))
)
return [groups[ext_expr] for ext_expr in ext_exprs]
for r in zip(*_results):
r = tuple(chain(*r))
_, row = r[0]
groups[row[-1]].append(r)

return [((query_plan.query_id(self), groups[ext_expr]),)
for ext_expr in ext_exprs]


class Processable(object):
Expand Down Expand Up @@ -506,11 +497,7 @@ def __init__(self, func, obj):

def __processor__(self, scope):
nested_scope = scope.nested(self._sub_query)
func_proc = _get_value_processor(nested_scope, self._func)
getter = nested_scope.gen_getter()
def process(result):
return func_proc(getter(result))
return process
return _get_value_processor(nested_scope, self._func)


class _arg_helper(object):
Expand Down
172 changes: 162 additions & 10 deletions test_sqlconstruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,28 +688,29 @@ class B(self.base_cls):
session = self.init()
b1, b2, b3 = B(name='b1'), B(name='b2'), B(name='b3')
session.add_all([
A(name='a1', b=b1), A(name='a2', b=b1), A(name='a3', b=b1),
A(name='a4', b=b2), A(name='a5', b=b2), A(name='a6', b=b2),
A(name='a7', b=b3), A(name='a8', b=b3), A(name='a9', b=b3),
A(name='a1', b=b1), A(name='a2', b=b1), A(name='a3'),
A(name='a4', b=b2), A(name='a5'), A(name='a6', b=b2),
A(name='a7'), A(name='a8', b=b3), A(name='a9', b=b3),
])
session.commit()

query = (
ConstructQuery({
'a_name': A.name,
'b_name': get_(apply_(capitalize, [B.name]), A.b),
'b_name': get_(if_(B.id, apply_(capitalize, [B.name]), '~'),
A.b),
})
.with_session(session.registry())
)
self.assertEqual(
tuple(dict(obj) for obj in query.all()),
({'a_name': 'a1', 'b_name': 'B1'},
{'a_name': 'a2', 'b_name': 'B1'},
{'a_name': 'a3', 'b_name': 'B1'},
{'a_name': 'a3', 'b_name': '~'},
{'a_name': 'a4', 'b_name': 'B2'},
{'a_name': 'a5', 'b_name': 'B2'},
{'a_name': 'a5', 'b_name': '~'},
{'a_name': 'a6', 'b_name': 'B2'},
{'a_name': 'a7', 'b_name': 'B3'},
{'a_name': 'a7', 'b_name': '~'},
{'a_name': 'a8', 'b_name': 'B3'},
{'a_name': 'a9', 'b_name': 'B3'}),
)
Expand All @@ -727,22 +728,24 @@ class B(self.base_cls):
session = self.init()
session.add_all([
A(name='a1', b=B(name='b1')),
A(name='a2', b=B(name='b2')),
A(name='a2'),
B(name='b2'),
A(name='a3', b=B(name='b3')),
])
session.commit()

query = (
ConstructQuery({
'a_name': A.name,
'b_name': get_(apply_(capitalize, [B.name]), A.b),
'b_name': get_(if_(B.id, apply_(capitalize, [B.name]), '~'),
A.b),
})
.with_session(session.registry())
)
self.assertEqual(
tuple(dict(obj) for obj in query.all()),
({'a_name': 'a1', 'b_name': 'B1'},
{'a_name': 'a2', 'b_name': 'B2'},
{'a_name': 'a2', 'b_name': '~'},
{'a_name': 'a3', 'b_name': 'B3'}),
)

Expand Down Expand Up @@ -839,3 +842,152 @@ class B(self.base_cls):
('b4', {'A1', 'A2', 'A3'}),
),
)

def test_nested(self):
"""
A <- B -> C -> D <- E
"""
class A(self.base_cls):
name = Column(String)

class B(self.base_cls):
name = Column(String)
a_id = Column('a_id', Integer, ForeignKey('a.id'))
a = relationship('A', backref='b_list')
c_id = Column('c_id', Integer, ForeignKey('c.id'))
c = relationship('C', backref='b_list')

class C(self.base_cls):
name = Column(String)
d_id = Column('d_id', Integer, ForeignKey('d.id'))
d = relationship('D', backref='c_list')

class D(self.base_cls):
name = Column(String)

class E(self.base_cls):
name = Column(String)
d_id = Column('d_id', Integer, ForeignKey('d.id'))
d = relationship('D', backref='e_list')

session = self.init()
a1, a2, a3 = A(name='a1'), A(name='a2'), A(name='a3')
d1 = D(name='d1',
c_list=[C(name='c1',
b_list=[B(name='b1'),
B(name='b2', a=a2),
B(name='b3', a=a3)]),
C(name='c2',
b_list=[B(name='b4', a=a1),
B(name='b5'),
B(name='b6', a=a3)]),
C(name='c3',
b_list=[B(name='b7', a=a1),
B(name='b8', a=a2),
B(name='b9')])],
e_list=[E(name='e1'), E(name='e2'), E(name='e3')])
session.add_all([a1, a2, a3, d1])
session.commit()

# A <- B -> C
r1 = tuple(dict(obj) for obj in ConstructQuery({
'a_name': A.name,
'b_names': map_(B.name, A.b_list),
'c_names': map_(get_(C.name, B.c), A.b_list)
}).with_session(session.registry()).order_by(A.name).all())
self.assertEqual(r1, (
{'a_name': 'a1', 'b_names': ['b4', 'b7'], 'c_names': ['c2', 'c3']},
{'a_name': 'a2', 'b_names': ['b2', 'b8'], 'c_names': ['c1', 'c3']},
{'a_name': 'a3', 'b_names': ['b3', 'b6'], 'c_names': ['c1', 'c2']},
))

# B -> C -> D
r2 = tuple(dict(obj) for obj in ConstructQuery({
'b_name': B.name,
'c_name': get_(C.name, B.c),
'd_name': get_(get_(D.name, C.d), B.c),
}).with_session(session.registry()).order_by(B.name).all())
self.assertEqual(r2, (
{'b_name': 'b1', 'c_name': 'c1', 'd_name': 'd1'},
{'b_name': 'b2', 'c_name': 'c1', 'd_name': 'd1'},
{'b_name': 'b3', 'c_name': 'c1', 'd_name': 'd1'},
{'b_name': 'b4', 'c_name': 'c2', 'd_name': 'd1'},
{'b_name': 'b5', 'c_name': 'c2', 'd_name': 'd1'},
{'b_name': 'b6', 'c_name': 'c2', 'd_name': 'd1'},
{'b_name': 'b7', 'c_name': 'c3', 'd_name': 'd1'},
{'b_name': 'b8', 'c_name': 'c3', 'd_name': 'd1'},
{'b_name': 'b9', 'c_name': 'c3', 'd_name': 'd1'},
))

# C -> D <- E
r3 = tuple(dict(obj) for obj in ConstructQuery({
'c_name': C.name,
'd_name': get_(D.name, C.d),
'e_names': get_(map_(E.name, D.e_list), C.d),
}).with_session(session.registry()).order_by(C.name).all())
self.assertEqual(r3, (
{'c_name': 'c1', 'd_name': 'd1', 'e_names': ['e1', 'e2', 'e3']},
{'c_name': 'c2', 'd_name': 'd1', 'e_names': ['e1', 'e2', 'e3']},
{'c_name': 'c3', 'd_name': 'd1', 'e_names': ['e1', 'e2', 'e3']},
))

# D <- C <- B
r4 = dict(ConstructQuery({
'd_name': D.name,
'c_names': map_(C.name, D.c_list),
'b_names': map_(map_(B.name, C.b_list), D.c_list),
}).with_session(session.registry()).order_by(D.name).one())
self.assertEqual(r4['d_name'], 'd1')
self.assertEqual(set(r4['c_names']), {'c1', 'c2', 'c3'})
self.assertEqual(set(map(frozenset, r4['b_names'])), {
frozenset({'b1', 'b2', 'b3'}),
frozenset({'b4', 'b5', 'b6'}),
frozenset({'b7', 'b8', 'b9'}),
})

@unittest.skip('TODO')
def test_with_define(self):

class A(self.base_cls):
name = Column(String)
b_id = Column(Integer, ForeignKey('b.id'))
b = relationship('B')

class B(self.base_cls):
name = Column(String)

@define
def full_name(a, b):
def body(a_name, b_name):
return ' '.join((a_name.capitalize(), b_name.capitalize()))
return body, [a.name, b.name]

session = self.init()
b1, b2, b3 = B(name='b1'), B(name='b2'), B(name='b3')
session.add_all([
A(name='a1', b=b1), A(name='a2', b=b1), A(name='a3', b=b1),
A(name='a4', b=b2), A(name='a5', b=b2), A(name='a6', b=b2),
A(name='a7', b=b3), A(name='a8', b=b3), A(name='a9', b=b3),
])
session.commit()

query = (
ConstructQuery({
# 'full_name': full_name.defn(A, A.b),
'full_name': apply_(full_name.func, args=[A.name, get_(B.name, A.b)]),
})
.with_session(session.registry())
)

self.assertEqual(
tuple(dict(obj) for obj in query.all()),
({'full_name': 'A1 B1'},
{'full_name': 'A2 B1'},
{'full_name': 'A3 B1'},
{'full_name': 'A4 B2'},
{'full_name': 'A5 B2'},
{'full_name': 'A6 B2'},
{'full_name': 'A7 B3'},
{'full_name': 'A8 B3'},
{'full_name': 'A9 B3'}),
)

0 comments on commit 41062e3

Please sign in to comment.