From e9d7f77e3054d960960b01b699372b3b9738ba7c Mon Sep 17 00:00:00 2001 From: Gorshkov Nikolay Date: Thu, 4 Jul 2024 12:17:15 +0500 Subject: [PATCH] feat: added compoundselect support --- peewee_async.py | 24 ++++++++++- tests/aio_model/test_selecting.py | 72 ++++++++++++++++++++++++------- 2 files changed, 79 insertions(+), 17 deletions(-) diff --git a/peewee_async.py b/peewee_async.py index f34fbaa..522530f 100644 --- a/peewee_async.py +++ b/peewee_async.py @@ -774,15 +774,35 @@ async def aio_exists(self, database): clone._offset = None return bool(await clone.aio_scalar()) + def union_all(self, rhs): + return AioModelCompoundSelectQuery(self.model, self, 'UNION ALL', rhs) + __add__ = union_all + + def union(self, rhs): + return AioModelCompoundSelectQuery(self.model, self, 'UNION', rhs) + __or__ = union + + def intersect(self, rhs): + return AioModelCompoundSelectQuery(self.model, self, 'INTERSECT', rhs) + __and__ = intersect + + def except_(self, rhs): + return AioModelCompoundSelectQuery(self.model, self, 'EXCEPT', rhs) + __sub__ = except_ + def aio_prefetch(self, *subqueries, **kwargs): return aio_prefetch(self, *subqueries, **kwargs) -class AioSelect(peewee.Select, AioSelectMixin): +class AioSelect(AioSelectMixin, peewee.Select): + pass + + +class AioModelSelect(AioSelectMixin, peewee.ModelSelect): pass -class AioModelSelect(peewee.ModelSelect, AioSelectMixin): +class AioModelCompoundSelectQuery(AioSelectMixin, peewee.ModelCompoundSelectQuery): pass diff --git a/tests/aio_model/test_selecting.py b/tests/aio_model/test_selecting.py index be69191..93722e1 100644 --- a/tests/aio_model/test_selecting.py +++ b/tests/aio_model/test_selecting.py @@ -1,7 +1,5 @@ -import peewee - -from peewee_async import AioModelRaw -from tests.conftest import manager_for_all_dbs, dbs_all +from peewee_async import AioModelRaw, AioModelCompoundSelectQuery +from tests.conftest import dbs_all from tests.models import TestModel, TestModelAlpha, TestModelBeta @@ -31,18 +29,62 @@ async def test_raw_select(db): assert list(result) == [obj1, obj2] -@manager_for_all_dbs -async def test_select_compound(manager): - obj1 = await manager.create(TestModel, text="Test 1") - obj2 = await manager.create(TestModel, text="Test 2") +@dbs_all +async def test_union_all(db): + obj1 = await TestModel.aio_create(text="1") + obj2 = await TestModel.aio_create(text="2") + query = ( + TestModel.select().where(TestModel.id == obj1.id) + + TestModel.select().where(TestModel.id == obj2.id) + + TestModel.select().where(TestModel.id == obj2.id) + ) + result = await query.aio_execute() + assert sorted(r.text for r in result) == ["1", "2", "2"] + + +@dbs_all +async def test_union(db): + obj1 = await TestModel.aio_create(text="1") + obj2 = await TestModel.aio_create(text="2") query = ( TestModel.select().where(TestModel.id == obj1.id) | + TestModel.select().where(TestModel.id == obj2.id) | TestModel.select().where(TestModel.id == obj2.id) ) - assert isinstance(query, peewee.ModelCompoundSelectQuery) - # NOTE: Two `AioModelSelect` when joining via `|` produce `ModelCompoundSelectQuery` - # without `aio_execute()` method, so using `database.aio_execute()` here. - result = await manager.database.aio_execute(query) - assert len(list(result)) == 2 - assert obj1 in list(result) - assert obj2 in list(result) + assert isinstance(query, AioModelCompoundSelectQuery) + result = await query.aio_execute() + assert sorted(r.text for r in result) == ["1", "2"] + + +@dbs_all +async def test_intersect(db): + await TestModel.aio_create(text="1") + await TestModel.aio_create(text="2") + await TestModel.aio_create(text="3") + query = ( + TestModel.select().where( + (TestModel.text == "1") | (TestModel.text == "2") + ) & + TestModel.select().where( + (TestModel.text == "2") | (TestModel.text == "3") + ) + ) + result = await query.aio_execute() + assert sorted(r.text for r in result) == ["2"] + + +@dbs_all +async def test_except(db): + await TestModel.aio_create(text="1") + await TestModel.aio_create(text="2") + await TestModel.aio_create(text="3") + query = ( + TestModel.select().where( + (TestModel.text == "1") | (TestModel.text == "2") | (TestModel.text == "3") + ) - + TestModel.select().where( + (TestModel.text == "2") + ) + ) + result = await query.aio_execute() + assert sorted(r.text for r in result) == ["1", "3"] \ No newline at end of file