Skip to content

Commit

Permalink
Base subquery class subclassed from sqlalchemy.orm.query:Query class.…
Browse files Browse the repository at this point in the history
… Refactored subquery classes.
  • Loading branch information
vmagamedov committed Feb 18, 2014
1 parent 23b3504 commit 83ebbb4
Showing 1 changed file with 47 additions and 57 deletions.
104 changes: 47 additions & 57 deletions sqlconstruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def process_rows(self, rows, session):
children = self.query_children(None)

_results = [[((self.query_id(None), row),) for row in rows]]
_results.extend(child.process(self, None, rows, session)
_results.extend(child.__process__(self, None, rows, session)
for child in children)

return {0: [tuple(chain(*r)) for r in zip(*_results)]}
Expand Down Expand Up @@ -167,24 +167,33 @@ def loop(result, _query_id=self.query_plan.query_id(self.query)):
return loop


class _QueryBase(object):
class _QueryBase(_SAQuery):
__hash = None

def __hash__(self):
raise NotImplementedError
if self.__hash is not None:
id_, hash_ = self.__hash
if id(self) == id_:
return hash_
return super(_QueryBase, self).__hash__()

def __set_hash__(self, hash):
self.__hash = (id(self), hash)

def __eq__(self, other):
return hash(self) == hash(other)

def __contains_column__(self, column):
raise NotImplementedError
return any(el.c.contains_column(column)
for el in self.statement.froms)

def __reference__(self):
return None

def __requires__(self):
return tuple()

def process(self, query_plan, outer_query, outer_rows, session):
def __process__(self, query_plan, outer_query, outer_rows, session):
raise NotImplementedError


Expand All @@ -198,42 +207,34 @@ class _CollectionSubQuery(_QueryBase):

class _RelativeObjectSubQuery(_QueryBase):

def __init__(self, ext_expr, int_expr, query, _hash=None):
self._ext_expr = ext_expr
self._int_expr = int_expr
self._sa_query = query
self._hash = _hash or hash((type(self), ext_expr, int_expr, query))
def __init__(self, ext_expr, int_expr):
self.__ext_expr = ext_expr
self.__int_expr = int_expr
super(_RelativeObjectSubQuery, self).__init__([int_expr])

@classmethod
def from_relation(cls, relation_property):
ext_expr, int_expr = relation_property.local_remote_pairs[0]
query = _SAQuery([relation_property.mapper.class_])
hash_ = hash((cls, relation_property))
return cls(ext_expr, int_expr, query, hash_)
query = cls(ext_expr, int_expr)
query.__set_hash__(hash((cls, relation_property)))
return query

def __reference__(self):
return self._ext_expr
return self.__ext_expr

def __requires__(self):
return [self._int_expr]

def __hash__(self):
return self._hash

def __contains_column__(self, column):
return any(el.c.contains_column(column)
for el in self._sa_query.statement.froms)
return [self.__int_expr]

def process(self, query_plan, outer_query, outer_rows, session):
ext_col_id = query_plan.column_id(outer_query, self._ext_expr)
def __process__(self, query_plan, outer_query, outer_rows, session):
ext_col_id = query_plan.column_id(outer_query, self.__ext_expr)
ext_exprs = [row[ext_col_id] for row in outer_rows]
if ext_exprs:
columns = query_plan.query_columns(self) + (self._int_expr,)
columns = query_plan.query_columns(self) + (self.__int_expr,)
rows = (
self._sa_query
self
.with_session(session)
.with_entities(*columns)
.filter(self._int_expr.in_(set(ext_exprs)))
.filter(self.__int_expr.in_(set(ext_exprs)))
.all()
)
else:
Expand All @@ -242,7 +243,7 @@ def process(self, query_plan, outer_query, outer_rows, session):
children = query_plan.query_children(self)

_results = [[((query_plan.query_id(self), row),) for row in rows]]
_results.extend(child.process(query_plan, self, rows, session)
_results.extend(child.__process__(query_plan, self, rows, session)
for child in children)

mapping = {}
Expand All @@ -261,48 +262,37 @@ def process(self, query_plan, outer_query, outer_rows, session):

class _RelativeCollectionSubQuery(_QueryBase):

def __init__(self, ext_expr, int_expr, query, _hash=None):
self._ext_expr = ext_expr
self._int_expr = int_expr
self._sa_query = query
self._hash = _hash or hash((type(self), ext_expr, int_expr, query))
def __init__(self, ext_expr, int_expr):
self.__ext_expr = ext_expr
self.__int_expr = int_expr
super(_RelativeCollectionSubQuery, self).__init__([int_expr])

@classmethod
def from_relation(cls, relation_property):
ext_expr, int_expr = relation_property.local_remote_pairs[0]
query = cls(ext_expr, int_expr)
if relation_property.secondary is not None:
ext_expr, int_expr = relation_property.local_remote_pairs[0]
query = (_SAQuery([relation_property.mapper.class_])
.join(relation_property.secondary,
relation_property.secondaryjoin))
else:
ext_expr, int_expr = relation_property.local_remote_pairs[0]
query = _SAQuery([relation_property.mapper.class_])
hash_ = hash((cls, relation_property))
return cls(ext_expr, int_expr, query, hash_)
query = query.join(relation_property.mapper.class_,
relation_property.secondaryjoin)
query.__set_hash__(hash((cls, relation_property)))
return query

def __reference__(self):
return self._ext_expr
return self.__ext_expr

def __requires__(self):
return [self._int_expr]

def __hash__(self):
return self._hash

def __contains_column__(self, column):
return any(el.c.contains_column(column)
for el in self._sa_query.statement.froms)
return [self.__int_expr]

def process(self, query_plan, outer_query, outer_rows, session):
ext_col_id = query_plan.column_id(outer_query, self._ext_expr)
def __process__(self, query_plan, outer_query, outer_rows, session):
ext_col_id = query_plan.column_id(outer_query, self.__ext_expr)
ext_exprs = [row[ext_col_id] for row in outer_rows]
if ext_exprs:
columns = query_plan.query_columns(self) + (self._int_expr,)
columns = query_plan.query_columns(self) + (self.__int_expr,)
rows = (
self._sa_query
self
.with_session(session)
.with_entities(*columns)
.filter(self._int_expr.in_(ext_exprs))
.filter(self.__int_expr.in_(ext_exprs))
.all()
)
else:
Expand All @@ -311,7 +301,7 @@ def process(self, query_plan, outer_query, outer_rows, session):
children = query_plan.query_children(self)

_results = [[((query_plan.query_id(self), row),) for row in rows]]
_results.extend(child.process(query_plan, self, rows, session)
_results.extend(child.__process__(query_plan, self, rows, session)
for child in children)

groups = defaultdict(list)
Expand Down

0 comments on commit 83ebbb4

Please sign in to comment.