From 627b70390c3ba8bced4c5274830fe86a48095c88 Mon Sep 17 00:00:00 2001 From: Vladimir Magamedov Date: Tue, 15 Apr 2014 18:12:13 +0300 Subject: [PATCH] Fixed bug in relative sub-queries, which were able to produce queries with in_ operator containing empty list. Thanks to brabadu. --- sqlconstruct.py | 10 +++--- tests.py | 87 ++++++++++++++++++++++++++++++------------------- tox.ini | 24 ++++++++++---- 3 files changed, 77 insertions(+), 44 deletions(-) diff --git a/sqlconstruct.py b/sqlconstruct.py index d915f96..5a6dcd4 100644 --- a/sqlconstruct.py +++ b/sqlconstruct.py @@ -323,13 +323,14 @@ def __execute__(self, query_plan, outer_query, outer_rows, session): ext_col_id = query_plan.column_id(outer_query, self.__ext_expr) ext_col_values = [row[ext_col_id] for row in outer_rows] + ext_col_values_set = set(ext_col_values) - {None} - if outer_rows: + if ext_col_values_set: rows = ( self .with_session(session) .with_entities(*chain(columns, (self.__int_expr,))) - .filter(self.__int_expr.in_(set(ext_col_values) - {None})) + .filter(self.__int_expr.in_(ext_col_values_set)) .all() ) else: @@ -389,13 +390,14 @@ def __execute__(self, query_plan, outer_query, outer_rows, session): ext_col_id = query_plan.column_id(outer_query, self.__ext_expr) ext_col_values = [row[ext_col_id] for row in outer_rows] + ext_col_values_set = set(ext_col_values) - {None} - if ext_col_values: + if ext_col_values_set: rows = ( self .with_session(session) .with_entities(*chain(columns, (self.__int_expr,))) - .filter(self.__int_expr.in_(set(ext_col_values) - {None})) + .filter(self.__int_expr.in_(ext_col_values_set)) .all() ) else: diff --git a/tests.py b/tests.py index b733c70..3996b05 100644 --- a/tests.py +++ b/tests.py @@ -16,11 +16,14 @@ except ImportError: import unittest +from mock import patch + import sqlalchemy from sqlalchemy import Table, Column, String, Integer, ForeignKey from sqlalchemy import create_engine, func from sqlalchemy.orm import Session, Query as QueryBase, relationship, aliased from sqlalchemy.orm import subqueryload, scoped_session, sessionmaker +from sqlalchemy.sql.operators import ColumnOperators from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta @@ -43,6 +46,7 @@ from sqlconstruct import Construct, Object, apply_, if_, define, QueryMixin from sqlconstruct import ConstructQuery, map_, get_, _Scope, _QueryPlan from sqlconstruct import ObjectSubQuery, CollectionSubQuery +from sqlconstruct import RelativeObjectSubQuery, RelativeCollectionSubQuery if SQLA_ge_09: @@ -524,8 +528,8 @@ def test_query_with_explicit_join(self): self.assertEqual(s2.b_id, 2) self.assertEqual(s2.b_name, 'b2') - @unittest.skipIf(SQLA_ge_08, 'SQLAlchemy < 0.8') - def test_query_with_implicit_join_lt_08(self): + @unittest.skipIf(SQLA_ge_09, 'SQLAlchemy < 0.9') + def test_query_with_implicit_join_lt_09(self): from sqlalchemy.exc import InvalidRequestError with self.assertRaises(InvalidRequestError) as e1: @@ -538,8 +542,11 @@ def test_query_with_implicit_join_lt_08(self): ) .join(self.a_cls) ) - self.assertEqual(e1.exception.args[0], - 'Could not find a FROM clause to join from') + if SQLA_ge_08: + self.assertIn("Don't know how to join", e1.exception.args[0]) + else: + self.assertEqual(e1.exception.args[0], + "Could not find a FROM clause to join from") with self.assertRaises(InvalidRequestError) as e2: ( @@ -551,36 +558,11 @@ def test_query_with_implicit_join_lt_08(self): ) .join(self.b_cls) ) - self.assertEqual(e2.exception.args[0], - 'Could not find a FROM clause to join from') - - @unittest.skipIf(not SQLA_ge_08 or SQLA_ge_09, '0.8 <= SQLAlchemy < 0.9') - def test_query_with_implicit_join_ge_08(self): - from sqlalchemy.exc import NoInspectionAvailable - - with self.assertRaises(NoInspectionAvailable) as e1: - ( - self.session.query( - Construct({'a_id': self.a_cls.id, - 'a_name': self.a_cls.name, - 'b_id': self.b_cls.id, - 'b_name': self.b_cls.name}), - ) - .join(self.a_cls) - ) - self.assertIn('No inspection system is available', e1.exception.args[0]) - - with self.assertRaises(NoInspectionAvailable) as e2: - ( - self.session.query( - Construct({'a_id': self.a_cls.id, - 'a_name': self.a_cls.name, - 'b_id': self.b_cls.id, - 'b_name': self.b_cls.name}), - ) - .join(self.b_cls) - ) - self.assertIn('No inspection system is available', e2.exception.args[0]) + if SQLA_ge_08: + self.assertIn("Don't know how to join", e2.exception.args[0]) + else: + self.assertEqual(e2.exception.args[0], + "Could not find a FROM clause to join from") @unittest.skip('optional') def test_performance(self): @@ -858,6 +840,43 @@ class B(self.base_cls): ), ) + def test_non_empty_in_op_in_relative_subqueries(self): + in_op = ColumnOperators.in_ + + class EmptyInOpError(Exception): + pass + + def wrapper(self, values): + if not values: + raise EmptyInOpError + return in_op(self, values) + + patcher = patch.object(ColumnOperators, 'in_', wrapper) + + class A(self.base_cls): + value = Column(String) + + class B(self.base_cls): + value = Column(String) + + session = self.init() + session.add_all([A(), A(), A()]) + session.commit() + + obj_sq = RelativeObjectSubQuery(A.value, B.value) + obj_query = ConstructQuery({'a_id': A.id, + 'b_id': get_(B.id, obj_sq)}, + session) + with patcher: + obj_query.all() + + list_sq = RelativeCollectionSubQuery(A.value, B.value) + list_query = ConstructQuery({'a_id': A.id, + 'b_id': map_(B.id, list_sq)}, + session) + with patcher: + list_query.all() + def test_nested(self): """ A <- B -> C -> D <- E diff --git a/tox.ini b/tox.ini index 430dc84..97a3e40 100644 --- a/tox.ini +++ b/tox.ini @@ -6,25 +6,29 @@ commands = basepython = python2.7 deps = nose + mock sqlalchemy==0.7.10 [testenv:py27sqla08] basepython = python2.7 deps = nose - sqlalchemy==0.8.4 + mock + sqlalchemy==0.8.6 [testenv:py27sqla09] basepython = python2.7 deps = nose - sqlalchemy==0.9.2 + mock + sqlalchemy==0.9.4 [testenv:py27sqla0X] basepython = python2.7 usedevelop = True deps = nose + mock -egit+git@github.com:zzzeek/sqlalchemy.git@master#egg=sqlalchemy @@ -32,25 +36,29 @@ deps = basepython = python3.3 deps = nose + mock sqlalchemy==0.7.10 [testenv:py33sqla08] basepython = python3.3 deps = nose - sqlalchemy==0.8.4 + mock + sqlalchemy==0.8.6 [testenv:py33sqla09] basepython = python3.3 deps = nose - sqlalchemy==0.9.2 + mock + sqlalchemy==0.9.4 [testenv:py33sqla0X] basepython = python3.3 usedevelop = True deps = nose + mock -egit+git@github.com:zzzeek/sqlalchemy.git@master#egg=sqlalchemy @@ -58,23 +66,27 @@ deps = basepython = pypy deps = nose + mock sqlalchemy==0.7.10 [testenv:pypysqla08] basepython = pypy deps = nose - sqlalchemy==0.8.4 + mock + sqlalchemy==0.8.6 [testenv:pypysqla09] basepython = pypy deps = nose - sqlalchemy==0.9.2 + mock + sqlalchemy==0.9.4 [testenv:pypysqla0X] basepython = pypy usedevelop = True deps = nose + mock -egit+git@github.com:zzzeek/sqlalchemy.git@master#egg=sqlalchemy