Skip to content

Commit

Permalink
Fix support for SQLAlchemy 2
Browse files Browse the repository at this point in the history
  • Loading branch information
pmdevita authored and mekanix committed Mar 23, 2024
1 parent 1ed0d5a commit 84d6bc8
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions ormar/queryset/queries/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def build_select_expression(self) -> sqlalchemy.sql.select:
self.select_from, limit_qry, on_clause
)

expr = sqlalchemy.sql.select(self.columns)
expr = sqlalchemy.sql.select(*self.columns)
expr = expr.select_from(self.select_from)

expr = self._apply_expression_modifiers(expr)
Expand Down Expand Up @@ -191,7 +191,7 @@ def _build_pagination_condition(
elif order.get_field_name_text() == pk_aliased_name:
maxes[pk_aliased_name] = order.get_text_clause()

limit_qry = sqlalchemy.sql.select([qry_text])
limit_qry = sqlalchemy.sql.select(qry_text)
limit_qry = limit_qry.select_from(self.select_from)
limit_qry = FilterQuery(filter_clauses=self.filter_clauses).apply(limit_qry)
limit_qry = FilterQuery(
Expand Down
4 changes: 2 additions & 2 deletions ormar/queryset/queryset.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,7 @@ async def values(
exclude_through=exclude_through,
)
column_map = alias_resolver.resolve_columns(
columns_names=list(cast(LegacyRow, rows[0]).keys())
columns_names=rows[0].keys()
)
result = [
{column_map.get(k): v for k, v in dict(x).items() if k in column_map}
Expand Down Expand Up @@ -724,7 +724,7 @@ async def _query_aggr_function(self, func_name: str, columns: List) -> Any:
)
select_columns = [x.apply_func(func, use_label=True) for x in select_actions]
expr = self.build_select_expression().alias(f"subquery_for_{func_name}")
expr = sqlalchemy.select(select_columns).select_from(expr)
expr = sqlalchemy.select(*select_columns).select_from(expr)
# print("\n", expr.compile(compile_kwargs={"literal_binds": True}))
result = await self.database.fetch_one(expr)
return dict(result) if len(result) > 1 else result[0] # type: ignore
Expand Down
4 changes: 2 additions & 2 deletions tests/test_model_definition/test_fields_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,8 @@ def test_combining_groups_together():
category_prefix = group._nested_groups[1]._nested_groups[1].actions[0].table_prefix
assert group_str == (
f"((product.name LIKE '%Test%') "
f"OR (({price_list_prefix}_price_lists.name LIKE 'Aa%') "
f"OR ({category_prefix}_categories.name IN ('Toys', 'Books'))))"
f"OR ({price_list_prefix}_price_lists.name LIKE 'Aa%') "
f"OR ({category_prefix}_categories.name IN ('Toys', 'Books')))"
)


Expand Down

0 comments on commit 84d6bc8

Please sign in to comment.