Skip to content

Commit

Permalink
Fixed bug in relative sub-queries, which were able to produce queries…
Browse files Browse the repository at this point in the history
… with in_ operator containing empty list. Thanks to brabadu.
  • Loading branch information
vmagamedov committed Apr 15, 2014
1 parent a54b9bc commit 627b703
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 44 deletions.
10 changes: 6 additions & 4 deletions sqlconstruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,13 +323,14 @@ def __execute__(self, query_plan, outer_query, outer_rows, session):

ext_col_id = query_plan.column_id(outer_query, self.__ext_expr)
ext_col_values = [row[ext_col_id] for row in outer_rows]
ext_col_values_set = set(ext_col_values) - {None}

if outer_rows:
if ext_col_values_set:
rows = (
self
.with_session(session)
.with_entities(*chain(columns, (self.__int_expr,)))
.filter(self.__int_expr.in_(set(ext_col_values) - {None}))
.filter(self.__int_expr.in_(ext_col_values_set))
.all()
)
else:
Expand Down Expand Up @@ -389,13 +390,14 @@ def __execute__(self, query_plan, outer_query, outer_rows, session):

ext_col_id = query_plan.column_id(outer_query, self.__ext_expr)
ext_col_values = [row[ext_col_id] for row in outer_rows]
ext_col_values_set = set(ext_col_values) - {None}

if ext_col_values:
if ext_col_values_set:
rows = (
self
.with_session(session)
.with_entities(*chain(columns, (self.__int_expr,)))
.filter(self.__int_expr.in_(set(ext_col_values) - {None}))
.filter(self.__int_expr.in_(ext_col_values_set))
.all()
)
else:
Expand Down
87 changes: 53 additions & 34 deletions tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@
except ImportError:
import unittest

from mock import patch

import sqlalchemy
from sqlalchemy import Table, Column, String, Integer, ForeignKey
from sqlalchemy import create_engine, func
from sqlalchemy.orm import Session, Query as QueryBase, relationship, aliased
from sqlalchemy.orm import subqueryload, scoped_session, sessionmaker
from sqlalchemy.sql.operators import ColumnOperators
from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta


Expand All @@ -43,6 +46,7 @@
from sqlconstruct import Construct, Object, apply_, if_, define, QueryMixin
from sqlconstruct import ConstructQuery, map_, get_, _Scope, _QueryPlan
from sqlconstruct import ObjectSubQuery, CollectionSubQuery
from sqlconstruct import RelativeObjectSubQuery, RelativeCollectionSubQuery


if SQLA_ge_09:
Expand Down Expand Up @@ -524,8 +528,8 @@ def test_query_with_explicit_join(self):
self.assertEqual(s2.b_id, 2)
self.assertEqual(s2.b_name, 'b2')

@unittest.skipIf(SQLA_ge_08, 'SQLAlchemy < 0.8')
def test_query_with_implicit_join_lt_08(self):
@unittest.skipIf(SQLA_ge_09, 'SQLAlchemy < 0.9')
def test_query_with_implicit_join_lt_09(self):
from sqlalchemy.exc import InvalidRequestError

with self.assertRaises(InvalidRequestError) as e1:
Expand All @@ -538,8 +542,11 @@ def test_query_with_implicit_join_lt_08(self):
)
.join(self.a_cls)
)
self.assertEqual(e1.exception.args[0],
'Could not find a FROM clause to join from')
if SQLA_ge_08:
self.assertIn("Don't know how to join", e1.exception.args[0])
else:
self.assertEqual(e1.exception.args[0],
"Could not find a FROM clause to join from")

with self.assertRaises(InvalidRequestError) as e2:
(
Expand All @@ -551,36 +558,11 @@ def test_query_with_implicit_join_lt_08(self):
)
.join(self.b_cls)
)
self.assertEqual(e2.exception.args[0],
'Could not find a FROM clause to join from')

@unittest.skipIf(not SQLA_ge_08 or SQLA_ge_09, '0.8 <= SQLAlchemy < 0.9')
def test_query_with_implicit_join_ge_08(self):
from sqlalchemy.exc import NoInspectionAvailable

with self.assertRaises(NoInspectionAvailable) as e1:
(
self.session.query(
Construct({'a_id': self.a_cls.id,
'a_name': self.a_cls.name,
'b_id': self.b_cls.id,
'b_name': self.b_cls.name}),
)
.join(self.a_cls)
)
self.assertIn('No inspection system is available', e1.exception.args[0])

with self.assertRaises(NoInspectionAvailable) as e2:
(
self.session.query(
Construct({'a_id': self.a_cls.id,
'a_name': self.a_cls.name,
'b_id': self.b_cls.id,
'b_name': self.b_cls.name}),
)
.join(self.b_cls)
)
self.assertIn('No inspection system is available', e2.exception.args[0])
if SQLA_ge_08:
self.assertIn("Don't know how to join", e2.exception.args[0])
else:
self.assertEqual(e2.exception.args[0],
"Could not find a FROM clause to join from")

@unittest.skip('optional')
def test_performance(self):
Expand Down Expand Up @@ -858,6 +840,43 @@ class B(self.base_cls):
),
)

def test_non_empty_in_op_in_relative_subqueries(self):
in_op = ColumnOperators.in_

class EmptyInOpError(Exception):
pass

def wrapper(self, values):
if not values:
raise EmptyInOpError
return in_op(self, values)

patcher = patch.object(ColumnOperators, 'in_', wrapper)

class A(self.base_cls):
value = Column(String)

class B(self.base_cls):
value = Column(String)

session = self.init()
session.add_all([A(), A(), A()])
session.commit()

obj_sq = RelativeObjectSubQuery(A.value, B.value)
obj_query = ConstructQuery({'a_id': A.id,
'b_id': get_(B.id, obj_sq)},
session)
with patcher:
obj_query.all()

list_sq = RelativeCollectionSubQuery(A.value, B.value)
list_query = ConstructQuery({'a_id': A.id,
'b_id': map_(B.id, list_sq)},
session)
with patcher:
list_query.all()

def test_nested(self):
"""
A <- B -> C -> D <- E
Expand Down
24 changes: 18 additions & 6 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -6,75 +6,87 @@ commands =
basepython = python2.7
deps =
nose
mock
sqlalchemy==0.7.10

[testenv:py27sqla08]
basepython = python2.7
deps =
nose
sqlalchemy==0.8.4
mock
sqlalchemy==0.8.6

[testenv:py27sqla09]
basepython = python2.7
deps =
nose
sqlalchemy==0.9.2
mock
sqlalchemy==0.9.4

[testenv:py27sqla0X]
basepython = python2.7
usedevelop = True
deps =
nose
mock
[email protected]:zzzeek/sqlalchemy.git@master#egg=sqlalchemy


[testenv:py33sqla07]
basepython = python3.3
deps =
nose
mock
sqlalchemy==0.7.10

[testenv:py33sqla08]
basepython = python3.3
deps =
nose
sqlalchemy==0.8.4
mock
sqlalchemy==0.8.6

[testenv:py33sqla09]
basepython = python3.3
deps =
nose
sqlalchemy==0.9.2
mock
sqlalchemy==0.9.4

[testenv:py33sqla0X]
basepython = python3.3
usedevelop = True
deps =
nose
mock
[email protected]:zzzeek/sqlalchemy.git@master#egg=sqlalchemy


[testenv:pypysqla07]
basepython = pypy
deps =
nose
mock
sqlalchemy==0.7.10

[testenv:pypysqla08]
basepython = pypy
deps =
nose
sqlalchemy==0.8.4
mock
sqlalchemy==0.8.6

[testenv:pypysqla09]
basepython = pypy
deps =
nose
sqlalchemy==0.9.2
mock
sqlalchemy==0.9.4

[testenv:pypysqla0X]
basepython = pypy
usedevelop = True
deps =
nose
mock
[email protected]:zzzeek/sqlalchemy.git@master#egg=sqlalchemy

0 comments on commit 627b703

Please sign in to comment.