Skip to content

Commit

Permalink
feat: typing database file
Browse files Browse the repository at this point in the history
  • Loading branch information
kalombos committed Jul 16, 2024
1 parent 1bce6fe commit a13b018
Showing 1 changed file with 24 additions and 24 deletions.
48 changes: 24 additions & 24 deletions peewee_async/databases.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
import contextlib
import logging
import warnings
from typing import Type, Optional, Any
from typing import Type, Optional, Any, AsyncIterator, Iterator

import peewee
from playhouse import postgres_ext as ext
from .utils import psycopg2, aiopg, pymysql, aiomysql, __log__

from .pool import PoolBackend, PostgresqlPoolBackend, MysqlPoolBackend
from peewee_async_compat import _patch_query_with_compat_methods, savepoint
from .connection import connection_context, ConnectionContextManager
from .pool import PoolBackend, PostgresqlPoolBackend, MysqlPoolBackend
from .transactions import Transaction
from peewee_async_compat import _patch_query_with_compat_methods, savepoint
from .utils import psycopg2, aiopg, pymysql, aiomysql, __log__


class AioDatabase:
_allow_sync = True # whether sync queries are allowed

pool_backend_cls: Type[PoolBackend]

def __init__(self, database: Optional[str], **kwargs: Any):
def __init__(self, database: Optional[str], **kwargs: Any) -> None:
super().__init__(database, **kwargs)
if not database:
raise Exception("Deferred initialization is not supported")
Expand All @@ -27,22 +27,22 @@ def __init__(self, database: Optional[str], **kwargs: Any):
**self.connect_params_async
)

async def aio_connect(self):
async def aio_connect(self) -> None:
"""Set up async connection on default event loop.
"""
await self.pool_backend.connect()

@property
def is_connected(self):
def is_connected(self) -> bool:
return self.pool_backend.is_connected

async def aio_close(self):
async def aio_close(self) -> None:
"""Close async connection.
"""
await self.pool_backend.terminate()

@contextlib.asynccontextmanager
async def aio_atomic(self):
async def aio_atomic(self) -> AsyncIterator[None]:
"""Similar to peewee `Database.atomic()` method, but returns
asynchronous context manager.
"""
Expand All @@ -57,14 +57,14 @@ async def aio_atomic(self):
if begin_transaction is True:
_connection_context.transaction_is_opened = False

def set_allow_sync(self, value):
def set_allow_sync(self, value: bool) -> None:
"""Allow or forbid sync queries for the database. See also
the :meth:`.allow_sync()` context manager.
"""
self._allow_sync = value

@contextlib.contextmanager
def allow_sync(self):
def allow_sync(self) -> Iterator[None]:
"""Allow sync queries within context. Close sync
connection on exit if connected.
Expand Down Expand Up @@ -129,7 +129,7 @@ async def aio_execute(self, query, fetch_results=None):
return await self.aio_execute_sql(sql, params, fetch_results=fetch_results)

#### Deprecated methods ####
def __setattr__(self, name, value):
def __setattr__(self, name, value) -> None:
if name == 'allow_sync':
warnings.warn(
"`.allow_sync` setter is deprecated, use either the "
Expand All @@ -139,7 +139,7 @@ def __setattr__(self, name, value):
else:
super().__setattr__(name, value)

def atomic_async(self):
def atomic_async(self) -> Any:
"""Similar to peewee `Database.atomic()` method, but returns
asynchronous context manager.
"""
Expand All @@ -149,7 +149,7 @@ def atomic_async(self):
)
return self.aio_atomic()

def savepoint_async(self, sid=None):
def savepoint_async(self, sid=None) -> Any:
"""Similar to peewee `Database.savepoint()` method, but returns
asynchronous context manager.
"""
Expand All @@ -159,21 +159,21 @@ def savepoint_async(self, sid=None):
)
return savepoint(self, sid=sid)

async def connect_async(self):
async def connect_async(self) -> None:
warnings.warn(
"`connect_async` is deprecated, use `aio_connect` instead.",
DeprecationWarning
)
await self.aio_connect()

async def close_async(self):
async def close_async(self) -> None:
warnings.warn(
"`close_async` is deprecated, use `aio_close` instead.",
DeprecationWarning
)
await self.aio_close()

def transaction_async(self):
def transaction_async(self) -> Any:
"""Similar to peewee `Database.transaction()` method, but returns
asynchronous context manager.
"""
Expand All @@ -194,7 +194,7 @@ class AioPostgresqlMixin(AioDatabase):
if psycopg2:
Error = psycopg2.Error

def init_async(self, enable_json=False, enable_hstore=False):
def init_async(self, enable_json: bool = False, enable_hstore: bool =False) -> None:
if not aiopg:
raise Exception("Error, aiopg is not installed!")
self._enable_json = enable_json
Expand Down Expand Up @@ -238,7 +238,7 @@ class PooledPostgresqlDatabase(AioPostgresqlMixin, peewee.PostgresqlDatabase):
See also:
http://peewee.readthedocs.io/en/latest/peewee/api.html#PostgresqlDatabase
"""
def init(self, database, **kwargs):
def init(self, database: Optional[str], **kwargs: Any) -> None:
self.min_connections = kwargs.pop('min_connections', 1)
self.max_connections = kwargs.pop('max_connections', 20)
connection_timeout = kwargs.pop('connection_timeout', None)
Expand Down Expand Up @@ -273,7 +273,7 @@ class PooledPostgresqlExtDatabase(
https://peewee.readthedocs.io/en/latest/peewee/playhouse.html#PostgresqlExtDatabase
"""

def init(self, database, **kwargs):
def init(self, database: Optional[str], **kwargs: Any) -> None:
self.min_connections = kwargs.pop('min_connections', 1)
self.max_connections = kwargs.pop('max_connections', 20)
connection_timeout = kwargs.pop('connection_timeout', None)
Expand Down Expand Up @@ -308,7 +308,7 @@ class PooledMySQLDatabase(AioDatabase, peewee.MySQLDatabase):
if pymysql:
Error = pymysql.Error

def init(self, database, **kwargs):
def init(self, database: Optional[str], **kwargs: Any) -> None:
if not aiomysql:
raise Exception("Error, aiomysql is not installed!")
self.min_connections = kwargs.pop('min_connections', 1)
Expand Down Expand Up @@ -347,7 +347,7 @@ class PostgresqlDatabase(AioPostgresqlMixin, peewee.PostgresqlDatabase):
See also:
http://peewee.readthedocs.io/en/latest/peewee/api.html#PostgresqlDatabase
"""
def init(self, database, **kwargs):
def init(self, database: Optional[str], **kwargs: Any) -> None:
warnings.warn(
"`PostgresqlDatabase` is deprecated, use `PooledPostgresqlDatabase` instead.",
DeprecationWarning
Expand All @@ -369,7 +369,7 @@ class MySQLDatabase(PooledMySQLDatabase):
See also:
http://peewee.readthedocs.io/en/latest/peewee/api.html#MySQLDatabase
"""
def init(self, database, **kwargs):
def init(self, database: Optional[str], **kwargs: Any) -> None:
warnings.warn(
"`MySQLDatabase` is deprecated, use `PooledMySQLDatabase` instead.",
DeprecationWarning
Expand All @@ -394,7 +394,7 @@ class PostgresqlExtDatabase(AioPostgresqlMixin, ext.PostgresqlExtDatabase):
https://peewee.readthedocs.io/en/latest/peewee/playhouse.html#PostgresqlExtDatabase
"""

def init(self, database, **kwargs):
def init(self, database: Optional[str], **kwargs: Any) -> None:
warnings.warn(
"`PostgresqlExtDatabase` is deprecated, use `PooledPostgresqlExtDatabase` instead.",
DeprecationWarning
Expand Down

0 comments on commit a13b018

Please sign in to comment.