From 0f6bd58975d884a2d562be445d96f6902ec018d3 Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Tue, 10 Oct 2023 09:48:53 -0700 Subject: [PATCH] review feedbacks --- src/snowflake/sqlalchemy/base.py | 17 +++++++---------- tests/test_orm.py | 31 +++++++++++++++++++++---------- 2 files changed, 28 insertions(+), 20 deletions(-) diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index 063651fa..f4afc49d 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ -124,7 +124,7 @@ def _setup_joins(self, args, raw_columns): raw_columns, left, right, onclause ) else: - (replace_from_obj_index) = self._join_place_explicit_left_side(left) + replace_from_obj_index = self._join_place_explicit_left_side(left) if replace_from_obj_index is not None: # splice into an existing element in the @@ -133,19 +133,16 @@ def _setup_joins(self, args, raw_columns): self.from_clauses = ( self.from_clauses[:replace_from_obj_index] - + ( - _Snowflake_Selectable_Join( # handle Snowflake BCR bcr-1057 - left_clause, - right, - onclause, - isouter=isouter, - full=full, - ), + + _Snowflake_Selectable_Join( # handle Snowflake BCR bcr-1057 + left_clause, + right, + onclause, + isouter=isouter, + full=full, ) + self.from_clauses[replace_from_obj_index + 1 :] ) else: - self.from_clauses = self.from_clauses + ( # handle Snowflake BCR bcr-1057 _Snowflake_Selectable_Join( diff --git a/tests/test_orm.py b/tests/test_orm.py index 9ea025d9..e485d737 100644 --- a/tests/test_orm.py +++ b/tests/test_orm.py @@ -339,8 +339,7 @@ class User(Base): Base.metadata.drop_all(con) -def test_outer_lateral_join(engine_testaccount): - pytest.skip("test case needs to be fixed") +def test_outer_lateral_join(engine_testaccount, caplog): Base = declarative_base() class Employee(Base): @@ -348,7 +347,6 @@ class Employee(Base): employee_id = Column(Integer, primary_key=True) last_name = Column(String) - department_id = Column(Integer, ForeignKey("departments.department_id")) class Department(Base): __tablename__ = "departments" @@ -358,17 +356,30 @@ class Department(Base): Base.metadata.create_all(engine_testaccount) session = Session(bind=engine_testaccount) - e1 = Employee(employee_id=101, last_name="Richards", department_id=1) - e2 = Employee(employee_id=102, last_name="Paulson", department_id=1) - e3 = Employee(employee_id=103, last_name="Johnson", department_id=2) + e1 = Employee(employee_id=101, last_name="Richards") d1 = Department(department_id=1, name="Engineering") - d2 = Department(department_id=2, name="Support") - session.add_all([e1, e2, e3, d1, d2]) + session.add_all([e1, d1]) session.commit() sub = select(Department).lateral() - query = select(Employee.employee_id).select_from(Employee).outerjoin(sub) - session.execute(query) + query = ( + select(Employee.employee_id, Department.department_id) + .select_from(Employee) + .outerjoin(sub) + ) + assert ( + str(query.compile(engine_testaccount)).replace("\n", "") + == "SELECT employees.employee_id, departments.department_id " + "FROM departments, employees LEFT OUTER JOIN LATERAL " + "(SELECT departments.department_id AS department_id, departments.name AS name " + "FROM departments) AS anon_1" + ) + with caplog.at_level(logging.DEBUG): + assert [res for res in session.execute(query)] + assert ( + "SELECT employees.employee_id, departments.department_id FROM departments" + in caplog.text + ) def test_lateral_join_without_condition(engine_testaccount, caplog):