From abc676e9f972f7e4251f2ef7f6e860c56b8f74aa Mon Sep 17 00:00:00 2001 From: Gorshkov Nikolay Date: Mon, 15 Jul 2024 13:04:11 +0500 Subject: [PATCH] feat: started typing --- peewee_async/connection.py | 22 ++++++++++++++------ peewee_async/databases.py | 9 ++++----- peewee_async/pool.py | 39 +++++++++++++++++++++--------------- peewee_async/transactions.py | 34 +++++++++++++++++++------------ peewee_async/utils.py | 5 ++++- pyproject.toml | 3 ++- setup.cfg | 16 +++++++++++++++ tests/test_common.py | 20 ------------------ 8 files changed, 86 insertions(+), 62 deletions(-) diff --git a/peewee_async/connection.py b/peewee_async/connection.py index b0bccad..ffef4cd 100644 --- a/peewee_async/connection.py +++ b/peewee_async/connection.py @@ -1,9 +1,13 @@ from contextvars import ContextVar -from typing import Optional +from types import TracebackType +from typing import Optional, Type + +from peewee_async.pool import PoolBackend +from peewee_async.utils import T_Connection class ConnectionContext: - def __init__(self, connection): + def __init__(self, connection: T_Connection) -> None: self.connection = connection # needs for to know whether begin a transaction or create a savepoint self.transaction_is_opened = False @@ -13,12 +17,12 @@ def __init__(self, connection): class ConnectionContextManager: - def __init__(self, pool_backend): + def __init__(self, pool_backend: PoolBackend[T_Connection]) -> None: self.pool_backend = pool_backend self.connection_context = connection_context.get() self.resuing_connection = self.connection_context is not None - async def __aenter__(self): + async def __aenter__(self) -> T_Connection: if self.connection_context is not None: connection = self.connection_context.connection else: @@ -27,7 +31,13 @@ async def __aenter__(self): connection_context.set(self.connection_context) return connection - async def __aexit__(self, *args): + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType] + ) -> None: if self.resuing_connection is False: - self.pool_backend.release(self.connection_context.connection) + if self.connection_context is not None: + self.pool_backend.release(self.connection_context.connection) connection_context.set(None) diff --git a/peewee_async/databases.py b/peewee_async/databases.py index 661ad6e..a7eb1f0 100644 --- a/peewee_async/databases.py +++ b/peewee_async/databases.py @@ -1,7 +1,7 @@ import contextlib import logging import warnings -from typing import Type +from typing import Type, Optional, Any import peewee from playhouse import postgres_ext as ext @@ -18,8 +18,10 @@ class AioDatabase: pool_backend_cls: Type[PoolBackend] - def __init__(self, database, **kwargs): + def __init__(self, database: Optional[str], **kwargs: Any): super().__init__(database, **kwargs) + if not database: + raise Exception("Deferred initialization is not supported") self.pool_backend = self.pool_backend_cls( database=self.database, **self.connect_params_async @@ -28,9 +30,6 @@ def __init__(self, database, **kwargs): async def aio_connect(self): """Set up async connection on default event loop. """ - if self.deferred: - raise Exception("Error, database not properly initialized " - "before opening connection") await self.pool_backend.connect() @property diff --git a/peewee_async/pool.py b/peewee_async/pool.py index 0d4f1d9..ac0eae7 100644 --- a/peewee_async/pool.py +++ b/peewee_async/pool.py @@ -1,49 +1,56 @@ import abc import asyncio +from typing import Any, Generic -from .utils import aiopg, aiomysql +from .utils import aiopg, aiomysql, T_Connection -class PoolBackend(metaclass=abc.ABCMeta): +class PoolBackend(Generic[T_Connection], metaclass=abc.ABCMeta): """Asynchronous database connection pool. """ - def __init__(self, *, database=None, **kwargs): + def __init__(self, *, database: str, **kwargs: Any) -> None: self.pool = None self.database = database self.connect_params = kwargs self._connection_lock = asyncio.Lock() @property - def is_connected(self): - return self.pool is not None and self.pool.closed is False + def is_connected(self) -> bool: + if self.pool is not None: + return self.pool.closed is False + return False - def has_acquired_connections(self): - return self.pool is not None and len(self.pool._used) > 0 + def has_acquired_connections(self) -> bool: + if self.pool is not None: + return len(self.pool._used) > 0 + return False - async def connect(self): + async def connect(self) -> None: async with self._connection_lock: if self.is_connected is False: await self.create() - async def acquire(self): + async def acquire(self) -> T_Connection: """Acquire connection from pool. """ if self.pool is None: await self.connect() + assert self.pool is not None, "Pool is not connected" return await self.pool.acquire() - def release(self, conn): + def release(self, conn: T_Connection) -> None: """Release connection to pool. """ + assert self.pool is not None, "Pool is not connected" self.pool.release(conn) @abc.abstractmethod - async def create(self): + async def create(self) -> None: """Create connection pool asynchronously. """ raise NotImplementedError - async def terminate(self): + async def terminate(self) -> None: """Terminate all pool connections. """ if self.pool is not None: @@ -51,11 +58,11 @@ async def terminate(self): await self.pool.wait_closed() -class PostgresqlPoolBackend(PoolBackend): +class PostgresqlPoolBackend(PoolBackend[aiopg.Connection]): """Asynchronous database connection pool. """ - async def create(self): + async def create(self) -> None: """Create connection pool asynchronously. """ if "connect_timeout" in self.connect_params: @@ -66,11 +73,11 @@ async def create(self): ) -class MysqlPoolBackend(PoolBackend): +class MysqlPoolBackend(PoolBackend[aiomysql.Connection]): """Asynchronous database connection pool. """ - async def create(self): + async def create(self) -> None: """Create connection pool asynchronously. """ self.pool = await aiomysql.create_pool( diff --git a/peewee_async/transactions.py b/peewee_async/transactions.py index 186b43f..6520803 100644 --- a/peewee_async/transactions.py +++ b/peewee_async/transactions.py @@ -1,48 +1,56 @@ import uuid +from types import TracebackType +from typing import Optional, Type + +from peewee_async.utils import T_Connection class Transaction: - def __init__(self, connection, is_savepoint=False): + def __init__(self, connection: T_Connection, is_savepoint: bool=False): self.connection = connection + self.savepoint: Optional[str] = None if is_savepoint: self.savepoint = f"PWASYNC__{uuid.uuid4().hex}" - else: - self.savepoint = None @property - def is_savepoint(self): + def is_savepoint(self) -> bool: return self.savepoint is not None - async def execute(self, sql): + async def execute(self, sql: str) -> None: async with self.connection.cursor() as cursor: await cursor.execute(sql) - async def begin(self): + async def begin(self) -> None: sql = "BEGIN" if self.savepoint: sql = f"SAVEPOINT {self.savepoint}" - return await self.execute(sql) + await self.execute(sql) - async def __aenter__(self): + async def __aenter__(self) -> 'Transaction': await self.begin() return self - async def __aexit__(self, exc_type, exc_value, exc_tb): + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType] + ) -> None: if exc_type is not None: await self.rollback() else: await self.commit() - async def commit(self): + async def commit(self) -> None: sql = "COMMIT" if self.savepoint: sql = f"RELEASE SAVEPOINT {self.savepoint}" - return await self.execute(sql) + await self.execute(sql) - async def rollback(self): + async def rollback(self) -> None: sql = "ROLLBACK" if self.savepoint: sql = f"ROLLBACK TO SAVEPOINT {self.savepoint}" - return await self.execute(sql) + await self.execute(sql) diff --git a/peewee_async/utils.py b/peewee_async/utils.py index 22fbbf6..59da329 100644 --- a/peewee_async/utils.py +++ b/peewee_async/utils.py @@ -1,4 +1,5 @@ import logging +from typing import TypeVar, Any try: import aiopg @@ -15,4 +16,6 @@ pymysql = None __log__ = logging.getLogger('peewee.async') -__log__.addHandler(logging.NullHandler()) \ No newline at end of file +__log__.addHandler(logging.NullHandler()) + +T_Connection = TypeVar("T_Connection", bound=Any) diff --git a/pyproject.toml b/pyproject.toml index 0398fcd..bcb6278 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,11 +19,12 @@ pytest-asyncio = { version = "^0.21.1", optional = true } pytest-mock = { version = "^3.14.0", optional = true } sphinx = { version = "^7.1.2", optional = true } sphinx-rtd-theme = { version = "^1.3.0rc1", optional = true } +mypy = { version = "^1.10.1", optional = true } [tool.poetry.extras] postgresql = ["aiopg"] mysql = ["aiomysql", "cryptography"] -develop = ["aiopg", "aiomysql", "cryptography", "pytest", "pytest-asyncio", "pytest-mock"] +develop = ["aiopg", "aiomysql", "cryptography", "pytest", "pytest-asyncio", "pytest-mock", "mypy"] docs = ["aiopg", "aiomysql", "cryptography", "sphinx", "sphinx-rtd-theme"] [build-system] diff --git a/setup.cfg b/setup.cfg index 0eca4f8..02f52c8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,2 +1,18 @@ [tool:pytest] asyncio_mode = auto + +[mypy] +python_version = 3.10 +ignore_missing_imports = True +no_implicit_optional = True +strict_equality = True +check_untyped_defs = True +warn_redundant_casts = True +warn_unused_configs = True +warn_unused_ignores = True +warn_return_any = True +disallow_any_generics = True +disallow_untyped_calls = True +disallow_untyped_defs = True +disallow_incomplete_defs = True +exclude = (venv|load-testing|examples) diff --git a/tests/test_common.py b/tests/test_common.py index 0e240f1..257dd40 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -173,26 +173,6 @@ async def get_conn(manager): assert manager.is_connected is False -@pytest.mark.parametrize( - "params, db_cls", - [ - (DB_DEFAULTS[name], db_cls) for name, db_cls in DB_CLASSES.items() - ] - -) -async def test_deferred_init(params, db_cls): - - database = db_cls(None) - assert database.deferred is True - - database.init(**params) - assert database.deferred is False - - TestModel._meta.database = database - TestModel.create_table(True) - TestModel.drop_table(True) - - @pytest.mark.parametrize( "params, db_cls", [