From 83ebbb4e791a147d3c67262cf520219e309e5f85 Mon Sep 17 00:00:00 2001 From: Vladimir Magamedov Date: Tue, 18 Feb 2014 18:01:42 +0200 Subject: [PATCH] Base subquery class subclassed from sqlalchemy.orm.query:Query class. Refactored subquery classes. --- sqlconstruct.py | 104 ++++++++++++++++++++++-------------------------- 1 file changed, 47 insertions(+), 57 deletions(-) diff --git a/sqlconstruct.py b/sqlconstruct.py index f41019b..3089b49 100644 --- a/sqlconstruct.py +++ b/sqlconstruct.py @@ -116,7 +116,7 @@ def process_rows(self, rows, session): children = self.query_children(None) _results = [[((self.query_id(None), row),) for row in rows]] - _results.extend(child.process(self, None, rows, session) + _results.extend(child.__process__(self, None, rows, session) for child in children) return {0: [tuple(chain(*r)) for r in zip(*_results)]} @@ -167,16 +167,25 @@ def loop(result, _query_id=self.query_plan.query_id(self.query)): return loop -class _QueryBase(object): +class _QueryBase(_SAQuery): + __hash = None def __hash__(self): - raise NotImplementedError + if self.__hash is not None: + id_, hash_ = self.__hash + if id(self) == id_: + return hash_ + return super(_QueryBase, self).__hash__() + + def __set_hash__(self, hash): + self.__hash = (id(self), hash) def __eq__(self, other): return hash(self) == hash(other) def __contains_column__(self, column): - raise NotImplementedError + return any(el.c.contains_column(column) + for el in self.statement.froms) def __reference__(self): return None @@ -184,7 +193,7 @@ def __reference__(self): def __requires__(self): return tuple() - def process(self, query_plan, outer_query, outer_rows, session): + def __process__(self, query_plan, outer_query, outer_rows, session): raise NotImplementedError @@ -198,42 +207,34 @@ class _CollectionSubQuery(_QueryBase): class _RelativeObjectSubQuery(_QueryBase): - def __init__(self, ext_expr, int_expr, query, _hash=None): - self._ext_expr = ext_expr - self._int_expr = int_expr - self._sa_query = query - self._hash = _hash or hash((type(self), ext_expr, int_expr, query)) + def __init__(self, ext_expr, int_expr): + self.__ext_expr = ext_expr + self.__int_expr = int_expr + super(_RelativeObjectSubQuery, self).__init__([int_expr]) @classmethod def from_relation(cls, relation_property): ext_expr, int_expr = relation_property.local_remote_pairs[0] - query = _SAQuery([relation_property.mapper.class_]) - hash_ = hash((cls, relation_property)) - return cls(ext_expr, int_expr, query, hash_) + query = cls(ext_expr, int_expr) + query.__set_hash__(hash((cls, relation_property))) + return query def __reference__(self): - return self._ext_expr + return self.__ext_expr def __requires__(self): - return [self._int_expr] - - def __hash__(self): - return self._hash - - def __contains_column__(self, column): - return any(el.c.contains_column(column) - for el in self._sa_query.statement.froms) + return [self.__int_expr] - def process(self, query_plan, outer_query, outer_rows, session): - ext_col_id = query_plan.column_id(outer_query, self._ext_expr) + def __process__(self, query_plan, outer_query, outer_rows, session): + ext_col_id = query_plan.column_id(outer_query, self.__ext_expr) ext_exprs = [row[ext_col_id] for row in outer_rows] if ext_exprs: - columns = query_plan.query_columns(self) + (self._int_expr,) + columns = query_plan.query_columns(self) + (self.__int_expr,) rows = ( - self._sa_query + self .with_session(session) .with_entities(*columns) - .filter(self._int_expr.in_(set(ext_exprs))) + .filter(self.__int_expr.in_(set(ext_exprs))) .all() ) else: @@ -242,7 +243,7 @@ def process(self, query_plan, outer_query, outer_rows, session): children = query_plan.query_children(self) _results = [[((query_plan.query_id(self), row),) for row in rows]] - _results.extend(child.process(query_plan, self, rows, session) + _results.extend(child.__process__(query_plan, self, rows, session) for child in children) mapping = {} @@ -261,48 +262,37 @@ def process(self, query_plan, outer_query, outer_rows, session): class _RelativeCollectionSubQuery(_QueryBase): - def __init__(self, ext_expr, int_expr, query, _hash=None): - self._ext_expr = ext_expr - self._int_expr = int_expr - self._sa_query = query - self._hash = _hash or hash((type(self), ext_expr, int_expr, query)) + def __init__(self, ext_expr, int_expr): + self.__ext_expr = ext_expr + self.__int_expr = int_expr + super(_RelativeCollectionSubQuery, self).__init__([int_expr]) @classmethod def from_relation(cls, relation_property): + ext_expr, int_expr = relation_property.local_remote_pairs[0] + query = cls(ext_expr, int_expr) if relation_property.secondary is not None: - ext_expr, int_expr = relation_property.local_remote_pairs[0] - query = (_SAQuery([relation_property.mapper.class_]) - .join(relation_property.secondary, - relation_property.secondaryjoin)) - else: - ext_expr, int_expr = relation_property.local_remote_pairs[0] - query = _SAQuery([relation_property.mapper.class_]) - hash_ = hash((cls, relation_property)) - return cls(ext_expr, int_expr, query, hash_) + query = query.join(relation_property.mapper.class_, + relation_property.secondaryjoin) + query.__set_hash__(hash((cls, relation_property))) + return query def __reference__(self): - return self._ext_expr + return self.__ext_expr def __requires__(self): - return [self._int_expr] - - def __hash__(self): - return self._hash - - def __contains_column__(self, column): - return any(el.c.contains_column(column) - for el in self._sa_query.statement.froms) + return [self.__int_expr] - def process(self, query_plan, outer_query, outer_rows, session): - ext_col_id = query_plan.column_id(outer_query, self._ext_expr) + def __process__(self, query_plan, outer_query, outer_rows, session): + ext_col_id = query_plan.column_id(outer_query, self.__ext_expr) ext_exprs = [row[ext_col_id] for row in outer_rows] if ext_exprs: - columns = query_plan.query_columns(self) + (self._int_expr,) + columns = query_plan.query_columns(self) + (self.__int_expr,) rows = ( - self._sa_query + self .with_session(session) .with_entities(*columns) - .filter(self._int_expr.in_(ext_exprs)) + .filter(self.__int_expr.in_(ext_exprs)) .all() ) else: @@ -311,7 +301,7 @@ def process(self, query_plan, outer_query, outer_rows, session): children = query_plan.query_children(self) _results = [[((query_plan.query_id(self), row),) for row in rows]] - _results.extend(child.process(query_plan, self, rows, session) + _results.extend(child.__process__(query_plan, self, rows, session) for child in children) groups = defaultdict(list)