Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: started typing #264

Merged
merged 1 commit into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions peewee_async/connection.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from contextvars import ContextVar
from typing import Optional
from types import TracebackType
from typing import Optional, Type

from peewee_async.pool import PoolBackend
from peewee_async.utils import T_Connection


class ConnectionContext:
def __init__(self, connection):
def __init__(self, connection: T_Connection) -> None:
self.connection = connection
# needs for to know whether begin a transaction or create a savepoint
self.transaction_is_opened = False
Expand All @@ -13,12 +17,12 @@ def __init__(self, connection):


class ConnectionContextManager:
def __init__(self, pool_backend):
def __init__(self, pool_backend: PoolBackend[T_Connection]) -> None:
self.pool_backend = pool_backend
self.connection_context = connection_context.get()
self.resuing_connection = self.connection_context is not None

async def __aenter__(self):
async def __aenter__(self) -> T_Connection:
if self.connection_context is not None:
connection = self.connection_context.connection
else:
Expand All @@ -27,7 +31,13 @@ async def __aenter__(self):
connection_context.set(self.connection_context)
return connection

async def __aexit__(self, *args):
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType]
) -> None:
if self.resuing_connection is False:
self.pool_backend.release(self.connection_context.connection)
if self.connection_context is not None:
self.pool_backend.release(self.connection_context.connection)
connection_context.set(None)
9 changes: 4 additions & 5 deletions peewee_async/databases.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import contextlib
import logging
import warnings
from typing import Type
from typing import Type, Optional, Any

import peewee
from playhouse import postgres_ext as ext
Expand All @@ -18,8 +18,10 @@ class AioDatabase:

pool_backend_cls: Type[PoolBackend]

def __init__(self, database, **kwargs):
def __init__(self, database: Optional[str], **kwargs: Any):
super().__init__(database, **kwargs)
if not database:
raise Exception("Deferred initialization is not supported")
self.pool_backend = self.pool_backend_cls(
database=self.database,
**self.connect_params_async
Expand All @@ -28,9 +30,6 @@ def __init__(self, database, **kwargs):
async def aio_connect(self):
"""Set up async connection on default event loop.
"""
if self.deferred:
raise Exception("Error, database not properly initialized "
"before opening connection")
await self.pool_backend.connect()

@property
Expand Down
39 changes: 23 additions & 16 deletions peewee_async/pool.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,68 @@
import abc
import asyncio
from typing import Any, Generic

from .utils import aiopg, aiomysql
from .utils import aiopg, aiomysql, T_Connection


class PoolBackend(metaclass=abc.ABCMeta):
class PoolBackend(Generic[T_Connection], metaclass=abc.ABCMeta):
"""Asynchronous database connection pool.
"""
def __init__(self, *, database=None, **kwargs):
def __init__(self, *, database: str, **kwargs: Any) -> None:
self.pool = None
self.database = database
self.connect_params = kwargs
self._connection_lock = asyncio.Lock()

@property
def is_connected(self):
return self.pool is not None and self.pool.closed is False
def is_connected(self) -> bool:
if self.pool is not None:
return self.pool.closed is False
return False

def has_acquired_connections(self):
return self.pool is not None and len(self.pool._used) > 0
def has_acquired_connections(self) -> bool:
if self.pool is not None:
return len(self.pool._used) > 0
return False

async def connect(self):
async def connect(self) -> None:
async with self._connection_lock:
if self.is_connected is False:
await self.create()

async def acquire(self):
async def acquire(self) -> T_Connection:
"""Acquire connection from pool.
"""
if self.pool is None:
await self.connect()
assert self.pool is not None, "Pool is not connected"
return await self.pool.acquire()

def release(self, conn):
def release(self, conn: T_Connection) -> None:
"""Release connection to pool.
"""
assert self.pool is not None, "Pool is not connected"
self.pool.release(conn)

@abc.abstractmethod
async def create(self):
async def create(self) -> None:
"""Create connection pool asynchronously.
"""
raise NotImplementedError

async def terminate(self):
async def terminate(self) -> None:
"""Terminate all pool connections.
"""
if self.pool is not None:
self.pool.terminate()
await self.pool.wait_closed()


class PostgresqlPoolBackend(PoolBackend):
class PostgresqlPoolBackend(PoolBackend[aiopg.Connection]):
"""Asynchronous database connection pool.
"""

async def create(self):
async def create(self) -> None:
"""Create connection pool asynchronously.
"""
if "connect_timeout" in self.connect_params:
Expand All @@ -66,11 +73,11 @@ async def create(self):
)


class MysqlPoolBackend(PoolBackend):
class MysqlPoolBackend(PoolBackend[aiomysql.Connection]):
"""Asynchronous database connection pool.
"""

async def create(self):
async def create(self) -> None:
"""Create connection pool asynchronously.
"""
self.pool = await aiomysql.create_pool(
Expand Down
34 changes: 21 additions & 13 deletions peewee_async/transactions.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,56 @@
import uuid
from types import TracebackType
from typing import Optional, Type

from peewee_async.utils import T_Connection


class Transaction:

def __init__(self, connection, is_savepoint=False):
def __init__(self, connection: T_Connection, is_savepoint: bool=False):
self.connection = connection
self.savepoint: Optional[str] = None
if is_savepoint:
self.savepoint = f"PWASYNC__{uuid.uuid4().hex}"
else:
self.savepoint = None

@property
def is_savepoint(self):
def is_savepoint(self) -> bool:
return self.savepoint is not None

async def execute(self, sql):
async def execute(self, sql: str) -> None:
async with self.connection.cursor() as cursor:
await cursor.execute(sql)

async def begin(self):
async def begin(self) -> None:
sql = "BEGIN"
if self.savepoint:
sql = f"SAVEPOINT {self.savepoint}"
return await self.execute(sql)
await self.execute(sql)

async def __aenter__(self):
async def __aenter__(self) -> 'Transaction':
await self.begin()
return self

async def __aexit__(self, exc_type, exc_value, exc_tb):
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType]
) -> None:
if exc_type is not None:
await self.rollback()
else:
await self.commit()

async def commit(self):
async def commit(self) -> None:

sql = "COMMIT"
if self.savepoint:
sql = f"RELEASE SAVEPOINT {self.savepoint}"
return await self.execute(sql)
await self.execute(sql)

async def rollback(self):
async def rollback(self) -> None:
sql = "ROLLBACK"
if self.savepoint:
sql = f"ROLLBACK TO SAVEPOINT {self.savepoint}"
return await self.execute(sql)
await self.execute(sql)
5 changes: 4 additions & 1 deletion peewee_async/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from typing import TypeVar, Any

try:
import aiopg
Expand All @@ -15,4 +16,6 @@
pymysql = None

__log__ = logging.getLogger('peewee.async')
__log__.addHandler(logging.NullHandler())
__log__.addHandler(logging.NullHandler())

T_Connection = TypeVar("T_Connection", bound=Any)
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ pytest-asyncio = { version = "^0.21.1", optional = true }
pytest-mock = { version = "^3.14.0", optional = true }
sphinx = { version = "^7.1.2", optional = true }
sphinx-rtd-theme = { version = "^1.3.0rc1", optional = true }
mypy = { version = "^1.10.1", optional = true }

[tool.poetry.extras]
postgresql = ["aiopg"]
mysql = ["aiomysql", "cryptography"]
develop = ["aiopg", "aiomysql", "cryptography", "pytest", "pytest-asyncio", "pytest-mock"]
develop = ["aiopg", "aiomysql", "cryptography", "pytest", "pytest-asyncio", "pytest-mock", "mypy"]
docs = ["aiopg", "aiomysql", "cryptography", "sphinx", "sphinx-rtd-theme"]

[build-system]
Expand Down
16 changes: 16 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -1,2 +1,18 @@
[tool:pytest]
asyncio_mode = auto

[mypy]
python_version = 3.10
ignore_missing_imports = True
no_implicit_optional = True
strict_equality = True
check_untyped_defs = True
warn_redundant_casts = True
warn_unused_configs = True
warn_unused_ignores = True
warn_return_any = True
disallow_any_generics = True
disallow_untyped_calls = True
disallow_untyped_defs = True
disallow_incomplete_defs = True
exclude = (venv|load-testing|examples)
20 changes: 0 additions & 20 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,26 +173,6 @@ async def get_conn(manager):
assert manager.is_connected is False


@pytest.mark.parametrize(
"params, db_cls",
[
(DB_DEFAULTS[name], db_cls) for name, db_cls in DB_CLASSES.items()
]

)
async def test_deferred_init(params, db_cls):

database = db_cls(None)
assert database.deferred is True

database.init(**params)
assert database.deferred is False

TestModel._meta.database = database
TestModel.create_table(True)
TestModel.drop_table(True)


@pytest.mark.parametrize(
"params, db_cls",
[
Expand Down
Loading