Skip to content

Commit

Permalink
Added ability to explicitly bind any expression to the particular que…
Browse files Browse the repository at this point in the history
…ry to avoid default lookup procedure. Refactored "from_relation" methods to also accept InstrumentedAttribute instances and to raise TypeError when invalid argument type was provided.
  • Loading branch information
vmagamedov committed Apr 22, 2014
1 parent 627b703 commit 2c5e0a0
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 69 deletions.
70 changes: 58 additions & 12 deletions sqlconstruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __init__(self, name, *exprs, **kw):
'ConstructQuery', 'construct_query_maker', 'Construct',
'ObjectSubQuery', 'CollectionSubQuery',
'RelativeObjectSubQuery', 'RelativeCollectionSubQuery',
'if_', 'apply_', 'map_', 'get_', 'define',
'bind', 'if_', 'apply_', 'map_', 'get_', 'define',
'QueryMixin',
)

Expand Down Expand Up @@ -146,6 +146,21 @@ def process_rows(self, rows, session):
return {0: [tuple(chain(*r)) for r in zip(*results)]}


class _BoundExpression(object):

def __init__(self, expr, subquery):
self.expr = expr
self.subquery = subquery

def __getattr__(self, name):
# TODO: return _BoundExpression(getattr(self.expr, name), self.subquery)
raise NotImplementedError


def bind(expr, subquery):
return _BoundExpression(expr, subquery)


class _Scope(object):

def __init__(self, query_plan, query=None, parent=None):
Expand All @@ -166,19 +181,30 @@ def lookup(self, column):
scope = scope.parent
return scope

def root(self):
def query_scope(self, query):
scope = self
while scope.query:
if scope.query is query:
return scope
scope = scope.parent
raise ValueError('Unknown query {0!r}'.format(query))

def root_scope(self):
scope = self
while scope.query:
scope = scope.parent
return scope

def nested(self, query):
ext_expr = query.__reference__()
if ext_expr is not None:
scope = self.lookup(ext_expr)
self.query_plan.add_expr(scope.query, ext_expr)
reference = query.__reference__()
if isinstance(reference, _BoundExpression):
scope = self.query_scope(reference.subquery)
self.query_plan.add_expr(reference.subquery, reference.expr)
elif reference is not None:
scope = self.lookup(reference)
self.query_plan.add_expr(scope.query, reference)
else:
scope = self.root()
scope = self.root_scope()
return type(self)(self.query_plan, query, scope)

def add(self, column, query=None):
Expand Down Expand Up @@ -302,9 +328,20 @@ def __init__(self, ext_expr, int_expr, scoped_session=None):
super(RelativeObjectSubQuery, self).__init__([int_expr])

@classmethod
def from_relation(cls, relation_property):
def from_relation(cls, relation):
if isinstance(relation, InstrumentedAttribute):
relation_property = relation.property
else:
relation_property = relation

if not isinstance(relation_property, RelationshipProperty):
raise TypeError('Invalid type provided: {0!r}'.format(relation))

ext_expr, int_expr = relation_property.local_remote_pairs[0]
query = cls(ext_expr, int_expr)
if relation_property.secondary is not None:
query = query.join(relation_property.mapper.class_,
relation_property.secondaryjoin)
query.__set_hash__(hash((cls, relation_property)))
return query

Expand Down Expand Up @@ -366,7 +403,15 @@ def __init__(self, ext_expr, int_expr, scoped_session=None):
super(RelativeCollectionSubQuery, self).__init__([int_expr])

@classmethod
def from_relation(cls, relation_property):
def from_relation(cls, relation):
if isinstance(relation, InstrumentedAttribute):
relation_property = relation.property
else:
relation_property = relation

if not isinstance(relation_property, RelationshipProperty):
raise TypeError('Invalid type provided: {0!r}'.format(relation))

ext_expr, int_expr = relation_property.local_remote_pairs[0]
query = cls(ext_expr, int_expr)
if relation_property.secondary is not None:
Expand Down Expand Up @@ -512,6 +557,8 @@ def __or__(self, value):
def _get_value_processor(scope, value):
if isinstance(value, ColumnElement):
return scope.add(value)
elif isinstance(value, _BoundExpression):
return scope.add(value.expr, value.subquery)
elif isinstance(value, QueryableAttribute):
return _get_value_processor(scope, value.__clause_element__())
elif isinstance(value, _Processable):
Expand Down Expand Up @@ -574,8 +621,7 @@ def __init__(self, func, collection):
if isinstance(collection, (CollectionSubQuery, RelativeCollectionSubQuery)):
sub_query = collection
else:
sub_query = (RelativeCollectionSubQuery
.from_relation(collection.property))
sub_query = RelativeCollectionSubQuery.from_relation(collection)
self._func = func
self._sub_query = sub_query

Expand All @@ -594,7 +640,7 @@ def __init__(self, func, obj):
if isinstance(obj, (ObjectSubQuery, RelativeObjectSubQuery)):
sub_query = obj
else:
sub_query = RelativeObjectSubQuery.from_relation(obj.property)
sub_query = RelativeObjectSubQuery.from_relation(obj)
self._func = func
self._sub_query = sub_query

Expand Down
70 changes: 55 additions & 15 deletions tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
except ImportError:
import unittest

from mock import patch
try:
from unittest.mock import patch
except ImportError:
from mock import patch

import sqlalchemy
from sqlalchemy import Table, Column, String, Integer, ForeignKey
Expand Down Expand Up @@ -44,7 +47,7 @@


from sqlconstruct import Construct, Object, apply_, if_, define, QueryMixin
from sqlconstruct import ConstructQuery, map_, get_, _Scope, _QueryPlan
from sqlconstruct import ConstructQuery, bind, map_, get_, _Scope, _QueryPlan
from sqlconstruct import ObjectSubQuery, CollectionSubQuery
from sqlconstruct import RelativeObjectSubQuery, RelativeCollectionSubQuery

Expand Down Expand Up @@ -183,7 +186,7 @@ def test_basic_construct(self):
'a_id': self.a_cls.id,
'a_name': self.a_cls.name,
})
self.assertEquals(set(struct._columns), {
self.assertEqual(set(struct._columns), {
self.a_cls.__table__.c.id,
self.a_cls.__table__.c.name,
})
Expand All @@ -201,7 +204,7 @@ def test_nested_construct(self):
'a_id': apply_(operator.add, [self.a_cls.id, 5]),
'a_name': apply_(operator.concat, [self.a_cls.name, '-test']),
})
self.assertEquals(set(struct._columns), {
self.assertEqual(set(struct._columns), {
self.a_cls.__table__.c.id,
self.a_cls.__table__.c.name,
})
Expand Down Expand Up @@ -248,15 +251,15 @@ def test_apply_with_columns(self):
add = lambda a, b: a + b

apl1 = apply_(add, [f1], {'b': f2})
self.assertEquals(columns_set(apl1), {c1, c2})
self.assertEqual(columns_set(apl1), {c1, c2})
self.assertEqual(proceed(apl1, {c1: 3, c2: 4}), 3 + 4)

apl2 = apply_(add, [c1], {'b': c2})
self.assertEquals(columns_set(apl2), {c1, c2})
self.assertEqual(columns_set(apl2), {c1, c2})
self.assertEqual(proceed(apl1, {c1: 4, c2: 5}), 4 + 5)

apl3 = apply_(add, [fn1], {'b': fn2})
self.assertEquals(columns_set(apl3), {fn1, fn2})
self.assertEqual(columns_set(apl3), {fn1, fn2})
self.assertEqual(proceed(apl3, {fn1: 5, fn2: 6}), 5 + 6)

def test_nested_apply(self):
Expand Down Expand Up @@ -293,7 +296,7 @@ def test_nested_apply(self):
]),
]),
])
self.assertEquals(columns_set(apl), {c1, c2})
self.assertEqual(columns_set(apl), {c1, c2})
self.assertEqual(proceed(apl, {c1: 4, c2: 5}), sum(range(10)))

def test_if(self):
Expand All @@ -304,20 +307,20 @@ def test_if(self):
c4 = self.b_cls.__table__.c.name

if1 = if_(True, then_=1, else_=2)
self.assertEquals(columns_set(if1), set())
self.assertEqual(columns_set(if1), set())
self.assertEqual(proceed(if1, {}), 1)

if2 = if_(False, then_=1, else_=2)
self.assertEquals(columns_set(if2), set())
self.assertEqual(columns_set(if2), set())
self.assertEqual(proceed(if2, {}), 2)

if3 = if_(c1, then_=c2, else_=c3)
self.assertEquals(columns_set(if3), {c1, c2, c3})
self.assertEqual(columns_set(if3), {c1, c2, c3})
self.assertEqual(proceed(if3, {c1: 0, c2: 3, c3: 6}), 6)
self.assertEqual(proceed(if3, {c1: 1, c2: 3, c3: 6}), 3)

if4 = if_(c1, then_=apply_(add, [c2, c3]), else_=apply_(add, [c3, c4]))
self.assertEquals(columns_set(if4), {c1, c2, c3, c4})
self.assertEqual(columns_set(if4), {c1, c2, c3, c4})
self.assertEqual(proceed(if4, {c1: 0, c2: 2, c3: 3, c4: 4}), 3 + 4)
self.assertEqual(proceed(if4, {c1: 1, c2: 2, c3: 3, c4: 4}), 2 + 3)

Expand Down Expand Up @@ -358,7 +361,7 @@ def test_defined_calls(self):
apl1 = defined_func.defn(self.a_cls, self.b_cls,
extra_id=3, extra_name='baz')
self.assertTrue(isinstance(apl1, apply_), type(apl1))
self.assertEquals(columns_set(apl1), {c1, c2, c3, c4})
self.assertEqual(columns_set(apl1), {c1, c2, c3, c4})
self.assertEqual(
proceed(apl1, {c1: 1, c2: 'foo', c3: 2, c4: 'bar'}),
(1 + 2 + 3, 'foo' + 'bar' + 'baz'),
Expand All @@ -367,7 +370,7 @@ def test_defined_calls(self):
apl2 = defined_func.defn(self.a_cls, self.b_cls,
extra_id=c1, extra_name=c2)
self.assertTrue(isinstance(apl2, apply_), type(apl2))
self.assertEquals(columns_set(apl2), {c1, c2, c3, c4})
self.assertEqual(columns_set(apl2), {c1, c2, c3, c4})
self.assertEqual(
proceed(apl2, {c1: 1, c2: 'foo', c3: 2, c4: 'bar'}),
(1 + 2 + 1, 'foo' + 'bar' + 'foo'),
Expand All @@ -377,7 +380,7 @@ def test_defined_calls(self):
extra_id=apply_(operator.add, [c1, c3]),
extra_name=apply_(operator.concat, [c2, c4]))
self.assertTrue(isinstance(apl3, apply_), type(apl3))
self.assertEquals(columns_set(apl3), {c1, c2, c3, c4})
self.assertEqual(columns_set(apl3), {c1, c2, c3, c4})
self.assertEqual(
proceed(apl3, {c1: 1, c2: 'foo', c3: 2, c4: 'bar'}),
(1 + 2 + (1 + 2), 'foo' + 'bar' + ('foo' + 'bar')),
Expand Down Expand Up @@ -1094,6 +1097,43 @@ class A(self.base_cls):
with self.assertRaises(NotImplementedError):
ConstructQuery({'id': A.id}).add_entity(A)

def test_bound_to_query_expressions(self):

class A(self.base_cls):
name = Column(String)
b_list = relationship('B')

class B(self.base_cls):
name = Column(String)
a_id = Column(Integer, ForeignKey('a.id'))

session = self.init()
session.add_all([
A(name='a1', b_list=[B(name='b1')]),
A(name='a2', b_list=[B(name='b4'), B(name='b5')]),
A(name='a3', b_list=[B(name='b7'), B(name='b8'), B(name='b9')]),
])
session.commit()

sq = (
RelativeObjectSubQuery.from_relation(A.b_list)
.group_by(B.a_id)
)
query = (
ConstructQuery({
'a_name': A.name,
'b_count': get_(bind(func.count(), sq), sq),
})
.order_by(A.name.asc())
.with_session(session.registry())
)
self.assertEqual(
[dict(obj) for obj in query.all()],
[{'a_name': 'a1', 'b_count': 1},
{'a_name': 'a2', 'b_count': 2},
{'a_name': 'a3', 'b_count': 3}],
)

@unittest.skip('optional')
def test_performance(self):

Expand Down
Loading

0 comments on commit 2c5e0a0

Please sign in to comment.