diff --git a/peewee_async.py b/peewee_async.py index de4aa87..b3cb385 100644 --- a/peewee_async.py +++ b/peewee_async.py @@ -397,7 +397,8 @@ async def aio_execute(self, query, fetch_results=None): # To make `Database.aio_execute` compatible with peewee's sync queries we # apply optional patching, it will do nothing for Aio-counterparts: _patch_query_with_compat_methods(query, None) - sql, params = query.sql() + ctx = self.get_sql_context() + sql, params = ctx.sql(query).query() fetch_results = fetch_results or getattr(query, 'fetch_results', None) return await self.aio_execute_sql(sql, params, fetch_results=fetch_results) @@ -694,7 +695,7 @@ async def fetch_results(self, cursor): return await self.make_async_query_wrapper(cursor) -class AioModelSelect(peewee.ModelSelect, AioQueryMixin): +class AioSelectMixin(AioQueryMixin): async def fetch_results(self, cursor): return await self.make_async_query_wrapper(cursor) @@ -723,6 +724,28 @@ async def aio_get(self, database=None): 'not exist:\nSQL: %s\nParams: %s' % (clone.model, sql, params)) + @peewee.database_required + async def aio_count(self, database, clear_limit=False): + clone = self.order_by().alias('_wrapped') + if clear_limit: + clone._limit = clone._offset = None + try: + if clone._having is None and clone._group_by is None and \ + clone._windows is None and clone._distinct is None and \ + clone._simple_distinct is not True: + clone = clone.select(peewee.SQL('1')) + except AttributeError: + pass + return await AioSelect([clone], [peewee.fn.COUNT(peewee.SQL('1'))]).aio_scalar(database) + + +class AioSelect(peewee.Select, AioSelectMixin): + pass + + +class AioModelSelect(peewee.ModelSelect, AioSelectMixin): + pass + class AioModel(peewee.Model): """Async version of **peewee.Model** that allows to execute queries asynchronously diff --git a/peewee_async_compat.py b/peewee_async_compat.py index d90e3fa..c1821af 100644 --- a/peewee_async_compat.py +++ b/peewee_async_compat.py @@ -78,6 +78,7 @@ def _patch_query_with_compat_methods(query, async_query_cls): if async_query_cls is AioModelSelect: query.aio_get = partial(async_query_cls.aio_get, query) query.aio_scalar = partial(async_query_cls.aio_scalar, query) + query.aio_count = partial(async_query_cls.aio_count, query) def _query_db(query): @@ -94,25 +95,13 @@ async def count(query, clear_limit=False): :return: number of objects in `select()` query """ - database = _query_db(query) - clone = query.clone() - if query._distinct or query._group_by or query._limit or query._offset: - if clear_limit: - clone._limit = clone._offset = None - sql, params = clone.sql() - wrapped = 'SELECT COUNT(1) FROM (%s) AS wrapped_select' % sql - async def fetch_results(cursor): - row = await cursor.fetchone() - if row: - return row[0] - else: - return row - result = await database.aio_execute_sql(wrapped, params, fetch_results) - return result or 0 - else: - clone._returning = [peewee.fn.Count(peewee.SQL('*'))] - clone._order_by = None - return (await scalar(clone)) or 0 + from peewee_async import AioModelSelect # noqa + warnings.warn( + "`count` is deprecated, use `query.aio_count` method.", + DeprecationWarning + ) + _patch_query_with_compat_methods(query, AioModelSelect) + return await query.aio_count(clear_limit=clear_limit) async def prefetch(sq, *subqueries, prefetch_type): diff --git a/tests/aio_model/test_shortcuts.py b/tests/aio_model/test_shortcuts.py index 0ec756b..c0fcc4e 100644 --- a/tests/aio_model/test_shortcuts.py +++ b/tests/aio_model/test_shortcuts.py @@ -43,3 +43,21 @@ async def test_aio_scalar(db): ).aio_scalar(as_tuple=True) == (2, 1) assert await TestModel.select().aio_scalar() is None + + +@dbs_all +async def test_count_query(db): + + for num in range(5): + await IntegerTestModel.aio_create(num=num) + count = await IntegerTestModel.select().limit(3).aio_count() + assert count == 3 + + +@dbs_all +async def test_count_query_clear_limit(db): + + for num in range(5): + await IntegerTestModel.aio_create(num=num) + count = await IntegerTestModel.select().limit(3).aio_count(clear_limit=True) + assert count == 5 diff --git a/tests/compat/test_shortcuts.py b/tests/compat/test_shortcuts.py new file mode 100644 index 0000000..a27717f --- /dev/null +++ b/tests/compat/test_shortcuts.py @@ -0,0 +1,71 @@ +import uuid + +import peewee + +from tests.conftest import manager_for_all_dbs +from tests.models import CompatTestModel + + +@manager_for_all_dbs +async def test_get_or_none(manager): + """Test get_or_none manager function.""" + text1 = "Test %s" % uuid.uuid4() + text2 = "Test %s" % uuid.uuid4() + + obj1 = await manager.create(CompatTestModel, text=text1) + obj2 = await manager.get_or_none(CompatTestModel, text=text1) + obj3 = await manager.get_or_none(CompatTestModel, text=text2) + + assert obj1 == obj2 + assert obj1 is not None + assert obj2 is not None + assert obj3 is None + + +@manager_for_all_dbs +async def test_count_query_with_limit(manager): + text = "Test %s" % uuid.uuid4() + await manager.create(CompatTestModel, text=text) + text = "Test %s" % uuid.uuid4() + await manager.create(CompatTestModel, text=text) + text = "Test %s" % uuid.uuid4() + await manager.create(CompatTestModel, text=text) + + count = await manager.count(CompatTestModel.select().limit(1)) + assert count == 1 + + +@manager_for_all_dbs +async def test_count_query(manager): + text = "Test %s" % uuid.uuid4() + await manager.create(CompatTestModel, text=text) + text = "Test %s" % uuid.uuid4() + await manager.create(CompatTestModel, text=text) + text = "Test %s" % uuid.uuid4() + await manager.create(CompatTestModel, text=text) + + count = await manager.count(CompatTestModel.select()) + assert count == 3 + + +@manager_for_all_dbs +async def test_scalar_query(manager): + + text = "Test %s" % uuid.uuid4() + await manager.create(CompatTestModel, text=text) + text = "Test %s" % uuid.uuid4() + await manager.create(CompatTestModel, text=text) + + fn = peewee.fn.Count(CompatTestModel.id) + count = await manager.scalar(CompatTestModel.select(fn)) + + assert count == 2 + + +@manager_for_all_dbs +async def test_create_obj(manager): + + text = "Test %s" % uuid.uuid4() + obj = await manager.create(CompatTestModel, text=text) + assert obj is not None + assert obj.text == text diff --git a/tests/test_shortcuts.py b/tests/test_shortcuts.py index 14e5cdb..562b21e 100644 --- a/tests/test_shortcuts.py +++ b/tests/test_shortcuts.py @@ -43,62 +43,6 @@ async def test_prefetch(manager, prefetch_type): assert tuple(result[0].betas[0].gammas) == (gamma_111, gamma_112) -@manager_for_all_dbs -async def test_get_or_none(manager): - """Test get_or_none manager function.""" - text1 = "Test %s" % uuid.uuid4() - text2 = "Test %s" % uuid.uuid4() - - obj1 = await manager.create(TestModel, text=text1) - obj2 = await manager.get_or_none(TestModel, text=text1) - obj3 = await manager.get_or_none(TestModel, text=text2) - - assert obj1 == obj2 - assert obj1 is not None - assert obj2 is not None - assert obj3 is None - - -@manager_for_all_dbs -async def test_count_query_with_limit(manager): - text = "Test %s" % uuid.uuid4() - await manager.create(TestModel, text=text) - text = "Test %s" % uuid.uuid4() - await manager.create(TestModel, text=text) - text = "Test %s" % uuid.uuid4() - await manager.create(TestModel, text=text) - - count = await manager.count(TestModel.select().limit(1)) - assert count == 1 - - -@manager_for_all_dbs -async def test_count_query(manager): - text = "Test %s" % uuid.uuid4() - await manager.create(TestModel, text=text) - text = "Test %s" % uuid.uuid4() - await manager.create(TestModel, text=text) - text = "Test %s" % uuid.uuid4() - await manager.create(TestModel, text=text) - - count = await manager.count(TestModel.select()) - assert count == 3 - - -@manager_for_all_dbs -async def test_scalar_query(manager): - - text = "Test %s" % uuid.uuid4() - await manager.create(TestModel, text=text) - text = "Test %s" % uuid.uuid4() - await manager.create(TestModel, text=text) - - fn = peewee.fn.Count(TestModel.id) - count = await manager.scalar(TestModel.select(fn)) - - assert count == 2 - - @manager_for_all_dbs async def test_delete_obj(manager): text = "Test %s" % uuid.uuid4() @@ -124,15 +68,6 @@ async def test_update_obj(manager): assert obj2.text == "Test update object" -@manager_for_all_dbs -async def test_create_obj(manager): - - text = "Test %s" % uuid.uuid4() - obj = await manager.create(TestModel, text=text) - assert obj is not None - assert obj.text == text - - @manager_for_all_dbs async def test_create_or_get(manager): text = "Test %s" % uuid.uuid4()