Skip to content

Commit

Permalink
chore: module is splited by files
Browse files Browse the repository at this point in the history
  • Loading branch information
kalombos committed Jul 11, 2024
1 parent 519443b commit 669b958
Show file tree
Hide file tree
Showing 13 changed files with 1,049 additions and 1,054 deletions.
1,043 changes: 0 additions & 1,043 deletions peewee_async.py

This file was deleted.

68 changes: 68 additions & 0 deletions peewee_async/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""
peewee-async
============
Asynchronous interface for `peewee`_ ORM powered by `asyncio`_:
https://github.com/05bit/peewee-async
.. _peewee: https://github.com/coleifer/peewee
.. _asyncio: https://docs.python.org/3/library/asyncio.html
Licensed under The MIT License (MIT)
Copyright (c) 2014, Alexey Kinëv <[email protected]>
"""
from importlib.metadata import version

from playhouse.db_url import register_database

from peewee_async_compat import count, execute, prefetch, scalar, savepoint, atomic, transaction, Manager
from .aio_model import aio_prefetch, AioModel
from .connection import connection_context
from .databases import (
PooledPostgresqlDatabase,
PooledPostgresqlExtDatabase,
PooledMySQLDatabase,
PostgresqlDatabase,
MySQLDatabase,
PostgresqlExtDatabase
)
from .pool import PostgresqlPoolBackend, MysqlPoolBackend
from .transactions import Transaction

__version__ = version('peewee-async')


__all__ = [
'PooledPostgresqlDatabase',
'PooledPostgresqlExtDatabase',
'PooledMySQLDatabase',
'Transaction',
'AioModel',
'aio_prefetch',
'connection_context',
'PostgresqlPoolBackend',
'MysqlPoolBackend',

# Compatibility API (deprecated in v1.0 release)
'Manager',
'execute',
'count',
'scalar',
'prefetch',
'atomic',
'transaction',
'savepoint',
]

register_database(PooledPostgresqlDatabase, 'postgres+pool+async', 'postgresql+pool+async')
register_database(PooledPostgresqlExtDatabase, 'postgresext+pool+async', 'postgresqlext+pool+async')
register_database(PooledMySQLDatabase, 'mysql+pool+async')


# DEPRECATED Databases

register_database(PostgresqlDatabase, 'postgres+async', 'postgresql+async')
register_database(MySQLDatabase, 'mysql+async')
register_database(PostgresqlExtDatabase, 'postgresext+async', 'postgresqlext+async')
314 changes: 314 additions & 0 deletions peewee_async/aio_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,314 @@
import peewee

from .result_wrappers import AsyncQueryWrapper


async def aio_prefetch(sq, *subqueries, prefetch_type):
"""Asynchronous version of the `prefetch()` from peewee."""
if not subqueries:
return sq

fixed_queries = peewee.prefetch_add_subquery(sq, subqueries, prefetch_type)
deps = {}
rel_map = {}

for pq in reversed(fixed_queries):
query_model = pq.model
if pq.fields:
for rel_model in pq.rel_models:
rel_map.setdefault(rel_model, [])
rel_map[rel_model].append(pq)

deps[query_model] = {}
id_map = deps[query_model]
has_relations = bool(rel_map.get(query_model))

result = await pq.query.aio_execute()

for instance in result:
if pq.fields:
pq.store_instance(instance, id_map)
if has_relations:
for rel in rel_map[query_model]:
rel.populate_instance(instance, deps[rel.model])

return result


class AioQueryMixin:
@peewee.database_required
async def aio_execute(self, database):
return await database.aio_execute(self)

async def make_async_query_wrapper(self, cursor):
return await AsyncQueryWrapper.make_for_all_rows(cursor, self)


class AioModelDelete(peewee.ModelDelete, AioQueryMixin):
async def fetch_results(self, cursor):
if self._returning:
return await self.make_async_query_wrapper(cursor)
return cursor.rowcount


class AioModelUpdate(peewee.ModelUpdate, AioQueryMixin):

async def fetch_results(self, cursor):
if self._returning:
return await self.make_async_query_wrapper(cursor)
return cursor.rowcount


class AioModelInsert(peewee.ModelInsert, AioQueryMixin):
async def fetch_results(self, cursor):
if self._returning is not None and len(self._returning) > 1:
return await self.make_async_query_wrapper(cursor)

if self._returning:
row = await cursor.fetchone()
return row[0] if row else None
else:
return await self._database.last_insert_id_async(cursor)


class AioModelRaw(peewee.ModelRaw, AioQueryMixin):
async def fetch_results(self, cursor):
return await self.make_async_query_wrapper(cursor)


class AioSelectMixin(AioQueryMixin):

async def fetch_results(self, cursor):
return await self.make_async_query_wrapper(cursor)

@peewee.database_required
async def aio_scalar(self, database, as_tuple=False):
"""
Get single value from ``select()`` query, i.e. for aggregation.
:return: result is the same as after sync ``query.scalar()`` call
"""
async def fetch_results(cursor):
return await cursor.fetchone()

rows = await database.aio_execute(self, fetch_results=fetch_results)

return rows[0] if rows and not as_tuple else rows

async def aio_get(self, database=None):
clone = self.paginate(1, 1)
try:
return (await clone.aio_execute(database))[0]
except IndexError:
sql, params = clone.sql()
raise self.model.DoesNotExist('%s instance matching query does '
'not exist:\nSQL: %s\nParams: %s' %
(clone.model, sql, params))

@peewee.database_required
async def aio_count(self, database, clear_limit=False):
clone = self.order_by().alias('_wrapped')
if clear_limit:
clone._limit = clone._offset = None
try:
if clone._having is None and clone._group_by is None and \
clone._windows is None and clone._distinct is None and \
clone._simple_distinct is not True:
clone = clone.select(peewee.SQL('1'))
except AttributeError:
pass
return await AioSelect([clone], [peewee.fn.COUNT(peewee.SQL('1'))]).aio_scalar(database)

@peewee.database_required
async def aio_exists(self, database):
clone = self.columns(peewee.SQL('1'))
clone._limit = 1
clone._offset = None
return bool(await clone.aio_scalar())

def union_all(self, rhs):
return AioModelCompoundSelectQuery(self.model, self, 'UNION ALL', rhs)
__add__ = union_all

def union(self, rhs):
return AioModelCompoundSelectQuery(self.model, self, 'UNION', rhs)
__or__ = union

def intersect(self, rhs):
return AioModelCompoundSelectQuery(self.model, self, 'INTERSECT', rhs)
__and__ = intersect

def except_(self, rhs):
return AioModelCompoundSelectQuery(self.model, self, 'EXCEPT', rhs)
__sub__ = except_

def aio_prefetch(self, *subqueries, **kwargs):
return aio_prefetch(self, *subqueries, **kwargs)


class AioSelect(AioSelectMixin, peewee.Select):
pass


class AioModelSelect(AioSelectMixin, peewee.ModelSelect):
pass


class AioModelCompoundSelectQuery(AioSelectMixin, peewee.ModelCompoundSelectQuery):
pass


class AioModel(peewee.Model):
"""Async version of **peewee.Model** that allows to execute queries asynchronously
with **aio_execute** method
Example::
class User(peewee_async.AioModel):
username = peewee.CharField(max_length=40, unique=True)
await User.select().where(User.username == 'admin').aio_execute()
Also it provides async versions of **peewee.Model** shortcuts
Example::
user = await User.aio_get(User.username == 'user')
"""

@classmethod
def select(cls, *fields):
is_default = not fields
if not fields:
fields = cls._meta.sorted_fields
return AioModelSelect(cls, fields, is_default=is_default)

@classmethod
def update(cls, __data=None, **update):
return AioModelUpdate(cls, cls._normalize_data(__data, update))

@classmethod
def insert(cls, __data=None, **insert):
return AioModelInsert(cls, cls._normalize_data(__data, insert))

@classmethod
def insert_many(cls, rows, fields=None):
return AioModelInsert(cls, insert=rows, columns=fields)

@classmethod
def insert_from(cls, query, fields):
columns = [getattr(cls, field) if isinstance(field, str)
else field for field in fields]
return AioModelInsert(cls, insert=query, columns=columns)

@classmethod
def raw(cls, sql, *params):
return AioModelRaw(cls, sql, params)

@classmethod
def delete(cls):
return AioModelDelete(cls)

async def aio_delete_instance(self, recursive=False, delete_nullable=False):
if recursive:
dependencies = self.dependencies(delete_nullable)
for query, fk in reversed(list(dependencies)):
model = fk.model
if fk.null and not delete_nullable:
await model.update(**{fk.name: None}).where(query).aio_execute()
else:
await model.delete().where(query).aio_execute()
return await type(self).delete().where(self._pk_expr()).aio_execute()

async def aio_save(self, force_insert=False, only=None):
field_dict = self.__data__.copy()
if self._meta.primary_key is not False:
pk_field = self._meta.primary_key
pk_value = self._pk
else:
pk_field = pk_value = None
if only is not None:
field_dict = self._prune_fields(field_dict, only)
elif self._meta.only_save_dirty and not force_insert:
field_dict = self._prune_fields(field_dict, self.dirty_fields)
if not field_dict:
self._dirty.clear()
return False

self._populate_unsaved_relations(field_dict)
rows = 1

if self._meta.auto_increment and pk_value is None:
field_dict.pop(pk_field.name, None)

if pk_value is not None and not force_insert:
if self._meta.composite_key:
for pk_part_name in pk_field.field_names:
field_dict.pop(pk_part_name, None)
else:
field_dict.pop(pk_field.name, None)
if not field_dict:
raise ValueError('no data to save!')
rows = await self.update(**field_dict).where(self._pk_expr()).aio_execute()
elif pk_field is not None:
pk = await self.insert(**field_dict).aio_execute()
if pk is not None and (self._meta.auto_increment or
pk_value is None):
self._pk = pk
# Although we set the primary-key, do not mark it as dirty.
self._dirty.discard(pk_field.name)
else:
await self.insert(**field_dict).aio_execute()

self._dirty -= set(field_dict) # Remove any fields we saved.
return rows

@classmethod
async def aio_get(cls, *query, **filters):
"""Async version of **peewee.Model.get**"""
sq = cls.select()
if query:
if len(query) == 1 and isinstance(query[0], int):
sq = sq.where(cls._meta.primary_key == query[0])
else:
sq = sq.where(*query)
if filters:
sq = sq.filter(**filters)
return await sq.aio_get()

@classmethod
async def aio_get_or_none(cls, *query, **filters):
"""
Async version of **peewee.Model.get_or_none**
"""
try:
return await cls.aio_get(*query, **filters)
except cls.DoesNotExist:
return None

@classmethod
async def aio_create(cls, **query):
inst = cls(**query)
await inst.aio_save(force_insert=True)
return inst

@classmethod
async def aio_get_or_create(cls, **kwargs):
defaults = kwargs.pop('defaults', {})
query = cls.select()
for field, value in kwargs.items():
query = query.where(getattr(cls, field) == value)

try:
return await query.aio_get(), False
except cls.DoesNotExist:
try:
if defaults:
kwargs.update(defaults)
async with cls._meta.database.aio_atomic():
return await cls.aio_create(**kwargs), True
except peewee.IntegrityError as exc:
try:
return await query.aio_get(), False
except cls.DoesNotExist:
raise exc
Loading

0 comments on commit 669b958

Please sign in to comment.