From 276bb4628a5717a6f0ed3981c7b33501843346f4 Mon Sep 17 00:00:00 2001 From: Bryan Forbes Date: Wed, 14 Feb 2024 17:45:20 -0600 Subject: [PATCH] Add typings Add typings to the project and check the project using mypy. `PostgresMessage` and `PoolConnectionProxy` were broken out into their own files to make it easier to add typing via stub (pyi) files. Since they are metaclasses which generate dynamic objects, we can't type them directly in their python module. --- .flake8 | 4 +- .github/workflows/release.yml | 2 +- .gitignore | 3 + .gitmodules | 2 +- MANIFEST.in | 3 +- asyncpg/__init__.py | 5 +- asyncpg/_asyncio_compat.py | 16 +- asyncpg/_version.py | 6 +- asyncpg/cluster.py | 297 +++-- asyncpg/compat.py | 51 +- asyncpg/connect_utils.py | 566 ++++++--- asyncpg/connection.py | 1349 ++++++++++++++++++---- asyncpg/connresource.py | 25 +- asyncpg/cursor.py | 213 +++- asyncpg/exceptions/__init__.py | 525 ++++----- asyncpg/exceptions/_base.py | 222 +--- asyncpg/exceptions/_postgres_message.py | 155 +++ asyncpg/exceptions/_postgres_message.pyi | 36 + asyncpg/introspection.py | 29 +- asyncpg/pgproto | 2 +- asyncpg/pool.py | 616 ++++++---- asyncpg/pool_connection_proxy.py | 91 ++ asyncpg/pool_connection_proxy.pyi | 284 +++++ asyncpg/prepared_stmt.py | 90 +- asyncpg/protocol/__init__.py | 2 + asyncpg/protocol/protocol.pyi | 300 +++++ asyncpg/py.typed | 0 asyncpg/serverversion.py | 24 +- asyncpg/transaction.py | 52 +- asyncpg/types.py | 107 +- asyncpg/utils.py | 20 +- pyproject.toml | 16 +- setup.py | 2 +- tests/test__sourcecode.py | 33 +- tools/generate_exceptions.py | 16 +- 35 files changed, 3849 insertions(+), 1315 deletions(-) create mode 100644 asyncpg/exceptions/_postgres_message.py create mode 100644 asyncpg/exceptions/_postgres_message.pyi create mode 100644 asyncpg/pool_connection_proxy.py create mode 100644 asyncpg/pool_connection_proxy.pyi create mode 100644 asyncpg/protocol/protocol.pyi create mode 100644 asyncpg/py.typed diff --git a/.flake8 b/.flake8 index decf40da..5311c61e 100644 --- a/.flake8 +++ b/.flake8 @@ -1,3 +1,5 @@ [flake8] +select = C90,E,F,W,Y0 ignore = E402,E731,W503,W504,E252 -exclude = .git,__pycache__,build,dist,.eggs,.github,.local,.venv,.tox +exclude = .git,__pycache__,build,dist,.eggs,.github,.local,.venv*,.tox +per-file-ignores = *.pyi: F401, F403, F405, F811, E127, E128, E203, E266, E301, E302, E305, E501, E701, E704, E741, B303, W503, W504 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index eef0799e..450f471e 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -22,7 +22,7 @@ jobs: github_token: ${{ secrets.RELEASE_BOT_GITHUB_TOKEN }} version_file: asyncpg/_version.py version_line_pattern: | - __version__\s*=\s*(?:['"])([[:PEP440:]])(?:['"]) + __version__(?:\s*:\s*typing\.Final)?\s*=\s*(?:['"])([[:PEP440:]])(?:['"]) - name: Stop if not approved if: steps.checkver.outputs.approved != 'true' diff --git a/.gitignore b/.gitignore index 21286094..a04d0b91 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,6 @@ docs/_build /.eggs /.vscode /.mypy_cache +/.venv* +/.tox +/.vim diff --git a/.gitmodules b/.gitmodules index c8d0b650..9dc433a1 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,3 @@ [submodule "asyncpg/pgproto"] path = asyncpg/pgproto - url = https://github.com/MagicStack/py-pgproto.git + url = https://github.com/bryanforbes/py-pgproto.git diff --git a/MANIFEST.in b/MANIFEST.in index 2389f6fa..a51fa57c 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,5 +1,6 @@ recursive-include docs *.py *.rst Makefile *.css recursive-include examples *.py recursive-include tests *.py *.pem -recursive-include asyncpg *.pyx *.pxd *.pxi *.py *.c *.h +recursive-include asyncpg *.pyx *.pxd *.pxi *.py *.pyi *.c *.h include LICENSE README.rst Makefile performance.png .flake8 +include asyncpg/py.typed diff --git a/asyncpg/__init__.py b/asyncpg/__init__.py index e8cd11eb..dff9f58f 100644 --- a/asyncpg/__init__.py +++ b/asyncpg/__init__.py @@ -4,6 +4,7 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +from __future__ import annotations from .connection import connect, Connection # NOQA from .exceptions import * # NOQA @@ -11,9 +12,9 @@ from .protocol import Record # NOQA from .types import * # NOQA - +from . import exceptions from ._version import __version__ # NOQA -__all__ = ('connect', 'create_pool', 'Pool', 'Record', 'Connection') +__all__ = ['connect', 'create_pool', 'Pool', 'Record', 'Connection'] __all__ += exceptions.__all__ # NOQA diff --git a/asyncpg/_asyncio_compat.py b/asyncpg/_asyncio_compat.py index ad7dfd8c..b6f515d7 100644 --- a/asyncpg/_asyncio_compat.py +++ b/asyncpg/_asyncio_compat.py @@ -4,10 +4,15 @@ # # SPDX-License-Identifier: PSF-2.0 +from __future__ import annotations import asyncio import functools import sys +import typing + +if typing.TYPE_CHECKING: + from . import compat if sys.version_info < (3, 11): from async_timeout import timeout as timeout_ctx @@ -15,7 +20,12 @@ from asyncio import timeout as timeout_ctx -async def wait_for(fut, timeout): +_T = typing.TypeVar('_T') + + +async def wait_for( + fut: compat.Awaitable[_T], timeout: float | None +) -> _T: """Wait for the single Future or coroutine to complete, with timeout. Coroutine will be wrapped in Task. @@ -65,7 +75,7 @@ async def wait_for(fut, timeout): return await fut -async def _cancel_and_wait(fut): +async def _cancel_and_wait(fut: asyncio.Future[_T]) -> None: """Cancel the *fut* future or task and wait until it completes.""" loop = asyncio.get_running_loop() @@ -82,6 +92,6 @@ async def _cancel_and_wait(fut): fut.remove_done_callback(cb) -def _release_waiter(waiter, *args): +def _release_waiter(waiter: asyncio.Future[typing.Any], *args: object) -> None: if not waiter.done(): waiter.set_result(None) diff --git a/asyncpg/_version.py b/asyncpg/_version.py index 67fd67ab..383fe4d2 100644 --- a/asyncpg/_version.py +++ b/asyncpg/_version.py @@ -10,4 +10,8 @@ # supported platforms, publish the packages on PyPI, merge the PR # to the target branch, create a Git tag pointing to the commit. -__version__ = '0.30.0.dev0' +from __future__ import annotations + +import typing + +__version__: typing.Final = '0.30.0.dev0' diff --git a/asyncpg/cluster.py b/asyncpg/cluster.py index 4467cc2a..8615d228 100644 --- a/asyncpg/cluster.py +++ b/asyncpg/cluster.py @@ -4,6 +4,7 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +from __future__ import annotations import asyncio import os @@ -17,28 +18,46 @@ import tempfile import textwrap import time +import typing import asyncpg from asyncpg import serverversion +from asyncpg import exceptions + +if sys.version_info < (3, 12): + from typing_extensions import Unpack +else: + from typing import Unpack + +if typing.TYPE_CHECKING: + import _typeshed + from . import types + from . import connection -_system = platform.uname().system +class _ConnectionSpec(typing.TypedDict): + host: str + port: str + + +_system: typing.Final = platform.uname().system if _system == 'Windows': - def platform_exe(name): + def platform_exe(name: str) -> str: if name.endswith('.exe'): return name return name + '.exe' else: - def platform_exe(name): + def platform_exe(name: str) -> str: return name -def find_available_port(): +def find_available_port() -> int | None: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) try: sock.bind(('127.0.0.1', 0)) - return sock.getsockname()[1] + sock_name: tuple[str, int] = sock.getsockname() + return sock_name[1] except Exception: return None finally: @@ -50,7 +69,18 @@ class ClusterError(Exception): class Cluster: - def __init__(self, data_dir, *, pg_config_path=None): + _data_dir: str + _pg_config_path: str | None + _pg_bin_dir: str | None + _pg_ctl: str | None + _daemon_pid: int | None + _daemon_process: subprocess.Popen[bytes] | None + _connection_addr: _ConnectionSpec | None + _connection_spec_override: _ConnectionSpec | None + + def __init__( + self, data_dir: str, *, pg_config_path: str | None = None + ) -> None: self._data_dir = data_dir self._pg_config_path = pg_config_path self._pg_bin_dir = ( @@ -63,21 +93,21 @@ def __init__(self, data_dir, *, pg_config_path=None): self._connection_addr = None self._connection_spec_override = None - def get_pg_version(self): + def get_pg_version(self) -> types.ServerVersion: return self._pg_version - def is_managed(self): + def is_managed(self) -> bool: return True - def get_data_dir(self): + def get_data_dir(self) -> str: return self._data_dir - def get_status(self): + def get_status(self) -> str: if self._pg_ctl is None: self._init_env() process = subprocess.run( - [self._pg_ctl, 'status', '-D', self._data_dir], + [typing.cast(str, self._pg_ctl), 'status', '-D', self._data_dir], stdout=subprocess.PIPE, stderr=subprocess.PIPE) stdout, stderr = process.stdout, process.stderr @@ -96,15 +126,24 @@ def get_status(self): return self._test_connection(timeout=0) else: raise ClusterError( - 'pg_ctl status exited with status {:d}: {}'.format( + 'pg_ctl status exited with status {:d}: {!r}'.format( process.returncode, stderr)) - async def connect(self, loop=None, **kwargs): - conn_info = self.get_connection_spec() + async def connect( + self, + loop: asyncio.AbstractEventLoop | None = None, + **kwargs: object + ) -> connection.Connection[typing.Any]: + conn_info = typing.cast( + 'dict[str, typing.Any]', self.get_connection_spec() + ) conn_info.update(kwargs) - return await asyncpg.connect(loop=loop, **conn_info) + return typing.cast( + 'connection.Connection[typing.Any]', + await asyncpg.connect(loop=loop, **conn_info) + ) - def init(self, **settings): + def init(self, **settings: str) -> str: """Initialize cluster.""" if self.get_status() != 'not-initialized': raise ClusterError( @@ -123,8 +162,12 @@ def init(self, **settings): extra_args = [] process = subprocess.run( - [self._pg_ctl, 'init', '-D', self._data_dir] + extra_args, - stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + [ + typing.cast(str, self._pg_ctl), 'init', '-D', self._data_dir + ] + extra_args, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT + ) output = process.stdout @@ -135,7 +178,13 @@ def init(self, **settings): return output.decode() - def start(self, wait=60, *, server_settings={}, **opts): + def start( + self, + wait: int = 60, + *, + server_settings: dict[str, str] = {}, + **opts: object + ) -> None: """Start the cluster.""" status = self.get_status() if status == 'running': @@ -178,17 +227,19 @@ def start(self, wait=60, *, server_settings={}, **opts): for k, v in server_settings.items(): extra_args.extend(['-c', '{}={}'.format(k, v)]) + pg_ctl = typing.cast(str, self._pg_ctl) + if _system == 'Windows': # On Windows we have to use pg_ctl as direct execution # of postgres daemon under an Administrative account # is not permitted and there is no easy way to drop # privileges. if os.getenv('ASYNCPG_DEBUG_SERVER'): - stdout = sys.stdout + stdout: int | typing.TextIO = sys.stdout print( 'asyncpg.cluster: Running', ' '.join([ - self._pg_ctl, 'start', '-D', self._data_dir, + pg_ctl, 'start', '-D', self._data_dir, '-o', ' '.join(extra_args) ]), file=sys.stderr, @@ -197,7 +248,7 @@ def start(self, wait=60, *, server_settings={}, **opts): stdout = subprocess.DEVNULL process = subprocess.run( - [self._pg_ctl, 'start', '-D', self._data_dir, + [pg_ctl, 'start', '-D', self._data_dir, '-o', ' '.join(extra_args)], stdout=stdout, stderr=subprocess.STDOUT) @@ -224,14 +275,14 @@ def start(self, wait=60, *, server_settings={}, **opts): self._test_connection(timeout=wait) - def reload(self): + def reload(self) -> None: """Reload server configuration.""" status = self.get_status() if status != 'running': raise ClusterError('cannot reload: cluster is not running') process = subprocess.run( - [self._pg_ctl, 'reload', '-D', self._data_dir], + [typing.cast(str, self._pg_ctl), 'reload', '-D', self._data_dir], stdout=subprocess.PIPE, stderr=subprocess.PIPE) stderr = process.stderr @@ -241,11 +292,21 @@ def reload(self): 'pg_ctl stop exited with status {:d}: {}'.format( process.returncode, stderr.decode())) - def stop(self, wait=60): + def stop(self, wait: int = 60) -> None: process = subprocess.run( - [self._pg_ctl, 'stop', '-D', self._data_dir, '-t', str(wait), - '-m', 'fast'], - stdout=subprocess.PIPE, stderr=subprocess.PIPE) + [ + typing.cast(str, self._pg_ctl), + 'stop', + '-D', + self._data_dir, + '-t', + str(wait), + '-m', + 'fast' + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE + ) stderr = process.stderr @@ -258,14 +319,14 @@ def stop(self, wait=60): self._daemon_process.returncode is None): self._daemon_process.kill() - def destroy(self): + def destroy(self) -> None: status = self.get_status() if status == 'stopped' or status == 'not-initialized': shutil.rmtree(self._data_dir) else: raise ClusterError('cannot destroy {} cluster'.format(status)) - def _get_connection_spec(self): + def _get_connection_spec(self) -> _ConnectionSpec | None: if self._connection_addr is None: self._connection_addr = self._connection_addr_from_pidfile() @@ -277,17 +338,26 @@ def _get_connection_spec(self): else: return self._connection_addr - def get_connection_spec(self): + return None + + def get_connection_spec(self) -> _ConnectionSpec: status = self.get_status() if status != 'running': raise ClusterError('cluster is not running') - return self._get_connection_spec() + spec = self._get_connection_spec() + + if spec is None: + raise ClusterError('cannot determine server connection address') + + return spec - def override_connection_spec(self, **kwargs): - self._connection_spec_override = kwargs + def override_connection_spec(self, **kwargs: str) -> None: + self._connection_spec_override = typing.cast(_ConnectionSpec, kwargs) - def reset_wal(self, *, oid=None, xid=None): + def reset_wal( + self, *, oid: int | None = None, xid: int | None = None + ) -> None: status = self.get_status() if status == 'not-initialized': raise ClusterError( @@ -297,7 +367,7 @@ def reset_wal(self, *, oid=None, xid=None): raise ClusterError( 'cannot modify WAL status: cluster is running') - opts = [] + opts: list[str] = [] if oid is not None: opts.extend(['-o', str(oid)]) if xid is not None: @@ -323,7 +393,7 @@ def reset_wal(self, *, oid=None, xid=None): 'pg_resetwal exited with status {:d}: {}'.format( process.returncode, stderr.decode())) - def reset_hba(self): + def reset_hba(self) -> None: """Remove all records from pg_hba.conf.""" status = self.get_status() if status == 'not-initialized': @@ -339,8 +409,16 @@ def reset_hba(self): raise ClusterError( 'cannot modify HBA records: {}'.format(e)) from e - def add_hba_entry(self, *, type='host', database, user, address=None, - auth_method, auth_options=None): + def add_hba_entry( + self, + *, + type: str = 'host', + database: str, + user: str, + address: str | None = None, + auth_method: str, + auth_options: dict[str, str] | None = None, + ) -> None: """Add a record to pg_hba.conf.""" status = self.get_status() if status == 'not-initialized': @@ -365,7 +443,7 @@ def add_hba_entry(self, *, type='host', database, user, address=None, if auth_options is not None: record += ' ' + ' '.join( - '{}={}'.format(k, v) for k, v in auth_options) + '{}={}'.format(k, v) for k, v in auth_options.items()) try: with open(pg_hba, 'a') as f: @@ -374,7 +452,7 @@ def add_hba_entry(self, *, type='host', database, user, address=None, raise ClusterError( 'cannot modify HBA records: {}'.format(e)) from e - def trust_local_connections(self): + def trust_local_connections(self) -> None: self.reset_hba() if _system != 'Windows': @@ -390,7 +468,7 @@ def trust_local_connections(self): if status == 'running': self.reload() - def trust_local_replication_by(self, user): + def trust_local_replication_by(self, user: str) -> None: if _system != 'Windows': self.add_hba_entry(type='local', database='replication', user=user, auth_method='trust') @@ -404,7 +482,7 @@ def trust_local_replication_by(self, user): if status == 'running': self.reload() - def _init_env(self): + def _init_env(self) -> None: if not self._pg_bin_dir: pg_config = self._find_pg_config(self._pg_config_path) pg_config_data = self._run_pg_config(pg_config) @@ -418,7 +496,7 @@ def _init_env(self): self._postgres = self._find_pg_binary('postgres') self._pg_version = self._get_pg_version() - def _connection_addr_from_pidfile(self): + def _connection_addr_from_pidfile(self) -> _ConnectionSpec | None: pidfile = os.path.join(self._data_dir, 'postmaster.pid') try: @@ -464,7 +542,7 @@ def _connection_addr_from_pidfile(self): 'port': portnum } - def _test_connection(self, timeout=60): + def _test_connection(self, timeout: int = 60) -> str: self._connection_addr = None loop = asyncio.new_event_loop() @@ -478,17 +556,24 @@ def _test_connection(self, timeout=60): continue try: - con = loop.run_until_complete( - asyncpg.connect(database='postgres', - user='postgres', - timeout=5, loop=loop, - **self._connection_addr)) + con: connection.Connection[ + typing.Any + ] = loop.run_until_complete( + asyncpg.connect( + database='postgres', + user='postgres', + timeout=5, loop=loop, + **typing.cast( + _ConnectionSpec, self._connection_addr + ) + ) + ) except (OSError, asyncio.TimeoutError, - asyncpg.CannotConnectNowError, - asyncpg.PostgresConnectionError): + exceptions.CannotConnectNowError, + exceptions.PostgresConnectionError): time.sleep(1) continue - except asyncpg.PostgresError: + except exceptions.PostgresError: # Any other error other than ServerNotReadyError or # ConnectionError is interpreted to indicate the server is # up. @@ -501,16 +586,19 @@ def _test_connection(self, timeout=60): return 'running' - def _run_pg_config(self, pg_config_path): + def _run_pg_config(self, pg_config_path: str) -> dict[str, str]: process = subprocess.run( pg_config_path, stdout=subprocess.PIPE, stderr=subprocess.PIPE) stdout, stderr = process.stdout, process.stderr if process.returncode != 0: - raise ClusterError('pg_config exited with status {:d}: {}'.format( - process.returncode, stderr)) + raise ClusterError( + 'pg_config exited with status {:d}: {!r}'.format( + process.returncode, stderr + ) + ) else: - config = {} + config: dict[str, str] = {} for line in stdout.splitlines(): k, eq, v = line.decode('utf-8').partition('=') @@ -519,7 +607,7 @@ def _run_pg_config(self, pg_config_path): return config - def _find_pg_config(self, pg_config_path): + def _find_pg_config(self, pg_config_path: str | None) -> str: if pg_config_path is None: pg_install = ( os.environ.get('PGINSTALLATION') @@ -529,7 +617,9 @@ def _find_pg_config(self, pg_config_path): pg_config_path = platform_exe( os.path.join(pg_install, 'pg_config')) else: - pathenv = os.environ.get('PATH').split(os.pathsep) + pathenv = typing.cast( + str, os.environ.get('PATH') + ).split(os.pathsep) for path in pathenv: pg_config_path = platform_exe( os.path.join(path, 'pg_config')) @@ -547,8 +637,10 @@ def _find_pg_config(self, pg_config_path): return pg_config_path - def _find_pg_binary(self, binary): - bpath = platform_exe(os.path.join(self._pg_bin_dir, binary)) + def _find_pg_binary(self, binary: str) -> str: + bpath = platform_exe( + os.path.join(typing.cast(str, self._pg_bin_dir), binary) + ) if not os.path.isfile(bpath): raise ClusterError( @@ -557,7 +649,7 @@ def _find_pg_binary(self, binary): return bpath - def _get_pg_version(self): + def _get_pg_version(self) -> types.ServerVersion: process = subprocess.run( [self._postgres, '--version'], stdout=subprocess.PIPE, stderr=subprocess.PIPE) @@ -565,7 +657,7 @@ def _get_pg_version(self): if process.returncode != 0: raise ClusterError( - 'postgres --version exited with status {:d}: {}'.format( + 'postgres --version exited with status {:d}: {!r}'.format( process.returncode, stderr)) version_string = stdout.decode('utf-8').strip(' \n') @@ -580,9 +672,14 @@ def _get_pg_version(self): class TempCluster(Cluster): - def __init__(self, *, - data_dir_suffix=None, data_dir_prefix=None, - data_dir_parent=None, pg_config_path=None): + def __init__( + self, + *, + data_dir_suffix: str | None = None, + data_dir_prefix: str | None = None, + data_dir_parent: _typeshed.StrPath | None = None, + pg_config_path: str | None = None, + ) -> None: self._data_dir = tempfile.mkdtemp(suffix=data_dir_suffix, prefix=data_dir_prefix, dir=data_dir_parent) @@ -590,10 +687,16 @@ def __init__(self, *, class HotStandbyCluster(TempCluster): - def __init__(self, *, - master, replication_user, - data_dir_suffix=None, data_dir_prefix=None, - data_dir_parent=None, pg_config_path=None): + def __init__( + self, + *, + master: _ConnectionSpec, + replication_user: str, + data_dir_suffix: str | None = None, + data_dir_prefix: str | None = None, + data_dir_parent: _typeshed.StrPath | None = None, + pg_config_path: str | None = None, + ) -> None: self._master = master self._repl_user = replication_user super().__init__( @@ -602,11 +705,11 @@ def __init__(self, *, data_dir_parent=data_dir_parent, pg_config_path=pg_config_path) - def _init_env(self): + def _init_env(self) -> None: super()._init_env() self._pg_basebackup = self._find_pg_binary('pg_basebackup') - def init(self, **settings): + def init(self, **settings: str) -> str: """Initialize cluster.""" if self.get_status() != 'not-initialized': raise ClusterError( @@ -641,7 +744,13 @@ def init(self, **settings): return output.decode() - def start(self, wait=60, *, server_settings={}, **opts): + def start( + self, + wait: int = 60, + *, + server_settings: dict[str, str] = {}, + **opts: object + ) -> None: if self._pg_version >= (12, 0): server_settings = server_settings.copy() server_settings['primary_conninfo'] = ( @@ -656,33 +765,43 @@ def start(self, wait=60, *, server_settings={}, **opts): class RunningCluster(Cluster): - def __init__(self, **kwargs): + conn_spec: _ConnectionSpec + + def __init__(self, **kwargs: Unpack[_ConnectionSpec]) -> None: self.conn_spec = kwargs - def is_managed(self): + def is_managed(self) -> bool: return False - def get_connection_spec(self): - return dict(self.conn_spec) + def get_connection_spec(self) -> _ConnectionSpec: + return typing.cast(_ConnectionSpec, dict(self.conn_spec)) - def get_status(self): + def get_status(self) -> str: return 'running' - def init(self, **settings): - pass + def init(self, **settings: str) -> str: # type: ignore[empty-body] + ... - def start(self, wait=60, **settings): - pass + def start(self, wait: int = 60, **settings: object) -> None: + ... - def stop(self, wait=60): - pass + def stop(self, wait: int = 60) -> None: + ... - def destroy(self): - pass + def destroy(self) -> None: + ... - def reset_hba(self): + def reset_hba(self) -> None: raise ClusterError('cannot modify HBA records of unmanaged cluster') - def add_hba_entry(self, *, type='host', database, user, address=None, - auth_method, auth_options=None): + def add_hba_entry( + self, + *, + type: str = 'host', + database: str, + user: str, + address: str | None = None, + auth_method: str, + auth_options: dict[str, str] | None = None, + ) -> None: raise ClusterError('cannot modify HBA records of unmanaged cluster') diff --git a/asyncpg/compat.py b/asyncpg/compat.py index 3eec9eb7..0ff6c6da 100644 --- a/asyncpg/compat.py +++ b/asyncpg/compat.py @@ -4,22 +4,26 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +from __future__ import annotations import pathlib import platform import typing import sys +if typing.TYPE_CHECKING: + import asyncio -SYSTEM = platform.uname().system +SYSTEM: typing.Final = platform.uname().system -if SYSTEM == 'Windows': + +if sys.platform == 'win32': import ctypes.wintypes CSIDL_APPDATA = 0x001a - def get_pg_home_directory() -> typing.Optional[pathlib.Path]: + def get_pg_home_directory() -> pathlib.Path | None: # We cannot simply use expanduser() as that returns the user's # home directory, whereas Postgres stores its config in # %AppData% on Windows. @@ -31,14 +35,14 @@ def get_pg_home_directory() -> typing.Optional[pathlib.Path]: return pathlib.Path(buf.value) / 'postgresql' else: - def get_pg_home_directory() -> typing.Optional[pathlib.Path]: + def get_pg_home_directory() -> pathlib.Path | None: try: return pathlib.Path.home() except (RuntimeError, KeyError): return None -async def wait_closed(stream): +async def wait_closed(stream: asyncio.StreamWriter) -> None: # Not all asyncio versions have StreamWriter.wait_closed(). if hasattr(stream, 'wait_closed'): try: @@ -59,3 +63,40 @@ async def wait_closed(stream): from ._asyncio_compat import timeout_ctx as timeout # noqa: F401 else: from asyncio import timeout as timeout # noqa: F401 + +if sys.version_info < (3, 9): + from typing import ( + AsyncIterable as AsyncIterable, + Awaitable as Awaitable, + Callable as Callable, + Coroutine as Coroutine, + Deque as deque, + Generator as Generator, + Iterable as Iterable, + Iterator as Iterator, + List as list, + OrderedDict as OrderedDict, + Sequence as Sequence, + Sized as Sized, + Tuple as tuple, + ) +else: + from builtins import ( # noqa: F401 + list as list, + tuple as tuple, + ) + from collections import ( # noqa: F401 + deque as deque, + OrderedDict as OrderedDict, + ) + from collections.abc import ( # noqa: F401 + AsyncIterable as AsyncIterable, + Awaitable as Awaitable, + Callable as Callable, + Coroutine as Coroutine, + Generator as Generator, + Iterable as Iterable, + Iterator as Iterator, + Sequence as Sequence, + Sized as Sized, + ) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 414231fd..a9789a28 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -4,9 +4,9 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +from __future__ import annotations import asyncio -import collections import enum import functools import getpass @@ -29,6 +29,58 @@ from . import exceptions from . import protocol +if typing.TYPE_CHECKING: + if sys.version_info < (3, 11): + from typing_extensions import Self + else: + from typing import Self + + from . import connection + +_ConnectionT = typing.TypeVar( + '_ConnectionT', + bound='connection.Connection[typing.Any]' +) +_ProtocolT = typing.TypeVar( + '_ProtocolT', + bound='protocol.Protocol[typing.Any]' +) +_AsyncProtocolT = typing.TypeVar( + '_AsyncProtocolT', bound='asyncio.protocols.Protocol' +) +_RecordT = typing.TypeVar('_RecordT', bound=protocol.Record) +_ParsedSSLType = typing.Union[ + ssl_module.SSLContext, typing.Literal[False] +] +_SSLStringValues = typing.Literal[ + 'disable', 'prefer', 'allow', 'require', 'verify-ca', 'verify-full' +] +_TPTupleType = compat.tuple[ + asyncio.WriteTransport, + _AsyncProtocolT +] +AddrType = typing.Union[ + compat.tuple[str, int], + str +] +HostType = typing.Union[compat.list[str], compat.tuple[str, ...], str] +PasswordType = typing.Union[ + str, + compat.Callable[[], str], + compat.Callable[[], compat.Awaitable[str]] +] +PortListType = typing.Union[ + compat.list[typing.Union[int, str]], + compat.list[int], + compat.list[str], +] +PortType = typing.Union[ + PortListType, + int, + str +] +SSLType = typing.Union[_ParsedSSLType, _SSLStringValues, bool] + class SSLMode(enum.IntEnum): disable = 0 @@ -39,48 +91,40 @@ class SSLMode(enum.IntEnum): verify_full = 5 @classmethod - def parse(cls, sslmode): + def parse(cls, sslmode: str | Self) -> Self: if isinstance(sslmode, cls): return sslmode - return getattr(cls, sslmode.replace('-', '_')) - - -_ConnectionParameters = collections.namedtuple( - 'ConnectionParameters', - [ - 'user', - 'password', - 'database', - 'ssl', - 'sslmode', - 'direct_tls', - 'server_settings', - 'target_session_attrs', - ]) - + return typing.cast( + 'Self', + getattr(cls, typing.cast(str, sslmode).replace('-', '_')) + ) -_ClientConfiguration = collections.namedtuple( - 'ConnectionConfiguration', - [ - 'command_timeout', - 'statement_cache_size', - 'max_cached_statement_lifetime', - 'max_cacheable_statement_size', - ]) +class _ConnectionParameters(typing.NamedTuple): + user: str + password: PasswordType | None + database: str + ssl: _ParsedSSLType | None + sslmode: SSLMode | None + direct_tls: bool + server_settings: dict[str, str] | None + target_session_attrs: SessionAttribute -_system = platform.uname().system +class _ClientConfiguration(typing.NamedTuple): + command_timeout: float | None + statement_cache_size: int + max_cached_statement_lifetime: int + max_cacheable_statement_size: int -if _system == 'Windows': - PGPASSFILE = 'pgpass.conf' -else: - PGPASSFILE = '.pgpass' +_system: typing.Final = platform.uname().system +PGPASSFILE: typing.Final = ( + 'pgpass.conf' if _system == 'Windows' else '.pgpass' +) -def _read_password_file(passfile: pathlib.Path) \ - -> typing.List[typing.Tuple[str, ...]]: +def _read_password_file(passfile: pathlib.Path) -> list[tuple[str, ...]]: passtab = [] try: @@ -122,11 +166,13 @@ def _read_password_file(passfile: pathlib.Path) \ def _read_password_from_pgpass( - *, passfile: typing.Optional[pathlib.Path], - hosts: typing.List[str], - ports: typing.List[int], - database: str, - user: str): + *, + passfile: pathlib.Path, + hosts: compat.Iterable[str], + ports: list[int], + database: str, + user: str +) -> str | None: """Parse the pgpass file and return the matching password. :return: @@ -158,7 +204,7 @@ def _read_password_from_pgpass( return None -def _validate_port_spec(hosts, port): +def _validate_port_spec(hosts: compat.Sized, port: PortType) -> list[int]: if isinstance(port, list): # If there is a list of ports, its length must # match that of the host list. @@ -166,42 +212,49 @@ def _validate_port_spec(hosts, port): raise exceptions.ClientConfigurationError( 'could not match {} port numbers to {} hosts'.format( len(port), len(hosts))) + return [int(p) for p in port] else: - port = [port for _ in range(len(hosts))] - - return port + return [int(port) for _ in range(len(hosts))] -def _parse_hostlist(hostlist, port, *, unquote=False): +def _parse_hostlist( + hostlist: str, + port: PortType | None, + *, + unquote: bool = False +) -> tuple[list[str], PortListType]: if ',' in hostlist: # A comma-separated list of host addresses. hostspecs = hostlist.split(',') else: hostspecs = [hostlist] - hosts = [] - hostlist_ports = [] + hosts: list[str] = [] + hostlist_ports: list[int] = [] + ports: list[int] | None = None if not port: portspec = os.environ.get('PGPORT') if portspec: if ',' in portspec: - default_port = [int(p) for p in portspec.split(',')] + temp_port: list[int] | int = [ + int(p) for p in portspec.split(',') + ] else: - default_port = int(portspec) + temp_port = int(portspec) else: - default_port = 5432 + temp_port = 5432 - default_port = _validate_port_spec(hostspecs, default_port) + default_port = _validate_port_spec(hostspecs, temp_port) else: - port = _validate_port_spec(hostspecs, port) + ports = _validate_port_spec(hostspecs, port) for i, hostspec in enumerate(hostspecs): if hostspec[0] == '/': # Unix socket addr = hostspec - hostspec_port = '' + hostspec_port: str = '' elif hostspec[0] == '[': # IPv6 address m = re.match(r'(?:\[([^\]]+)\])(?::([0-9]+))?', hostspec) @@ -230,13 +283,13 @@ def _parse_hostlist(hostlist, port, *, unquote=False): else: hostlist_ports.append(default_port[i]) - if not port: - port = hostlist_ports + if not ports: + ports = hostlist_ports - return hosts, port + return hosts, ports -def _parse_tls_version(tls_version): +def _parse_tls_version(tls_version: str) -> ssl_module.TLSVersion: if tls_version.startswith('SSL'): raise exceptions.ClientConfigurationError( f"Unsupported TLS version: {tls_version}" @@ -249,7 +302,7 @@ def _parse_tls_version(tls_version): ) -def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]: +def _dot_postgresql_path(filename: str) -> pathlib.Path | None: try: homedir = pathlib.Path.home() except (RuntimeError, KeyError): @@ -258,15 +311,34 @@ def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]: return (homedir / '.postgresql' / filename).resolve() -def _parse_connect_dsn_and_args(*, dsn, host, port, user, - password, passfile, database, ssl, - direct_tls, server_settings, - target_session_attrs): +def _parse_connect_dsn_and_args( + *, + dsn: str | None, + host: HostType | None, + port: PortType | None, + user: str | None, + password: str | None, + passfile: str | None, + database: str | None, + ssl: SSLType | None, + direct_tls: bool, + server_settings: dict[str, str] | None, + target_session_attrs: SessionAttribute | None, +) -> tuple[list[tuple[str, int] | str], _ConnectionParameters]: # `auth_hosts` is the version of host information for the purposes # of reading the pgpass file. - auth_hosts = None - sslcert = sslkey = sslrootcert = sslcrl = sslpassword = None + auth_hosts: list[str] | tuple[str, ...] | None = None + sslcert: str | pathlib.Path | None = None + sslkey: str | pathlib.Path | None = None + sslrootcert: str | pathlib.Path | None = None + sslcrl: str | pathlib.Path | None = None + sslpassword = None ssl_min_protocol_version = ssl_max_protocol_version = None + ssl_val: SSLType | str | None = ssl + ssl_parsed: _ParsedSSLType | None = None + target_session_attrs_val: ( + SessionAttribute | str | None + ) = target_session_attrs if dsn: parsed = urllib.parse.urlparse(dsn) @@ -306,10 +378,12 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, password = urllib.parse.unquote(dsn_password) if parsed.query: - query = urllib.parse.parse_qs(parsed.query, strict_parsing=True) - for key, val in query.items(): - if isinstance(val, list): - query[key] = val[-1] + query: dict[str, str] = { + key: val[-1] if isinstance(val, list) else val + for key, val in urllib.parse.parse_qs( + parsed.query, strict_parsing=True + ).items() + } if 'port' in query: val = query.pop('port') @@ -348,8 +422,8 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, if 'sslmode' in query: val = query.pop('sslmode') - if ssl is None: - ssl = val + if ssl_val is None: + ssl_val = val if 'sslcert' in query: sslcert = query.pop('sslcert') @@ -380,8 +454,8 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, dsn_target_session_attrs = query.pop( 'target_session_attrs' ) - if target_session_attrs is None: - target_session_attrs = dsn_target_session_attrs + if target_session_attrs_val is None: + target_session_attrs_val = dsn_target_session_attrs if query: if server_settings is None: @@ -425,7 +499,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, else: port = int(port) - port = _validate_port_spec(host, port) + validated_ports = _validate_port_spec(host, port) if user is None: user = os.getenv('PGUSER') @@ -456,21 +530,21 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, if passfile is None: homedir = compat.get_pg_home_directory() if homedir: - passfile = homedir / PGPASSFILE + passfile_path: pathlib.Path | None = homedir / PGPASSFILE else: - passfile = None + passfile_path = None else: - passfile = pathlib.Path(passfile) + passfile_path = pathlib.Path(passfile) - if passfile is not None: + if passfile_path is not None: password = _read_password_from_pgpass( - hosts=auth_hosts, ports=port, + hosts=auth_hosts, ports=validated_ports, database=database, user=user, - passfile=passfile) + passfile=passfile_path) - addrs = [] + addrs: list[AddrType] = [] have_tcp_addrs = False - for h, p in zip(host, port): + for h, p in zip(host, validated_ports): if h.startswith('/'): # UNIX socket name if '.s.PGSQL.' not in h: @@ -485,15 +559,15 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, raise exceptions.InternalClientError( 'could not determine the database address to connect to') - if ssl is None: - ssl = os.getenv('PGSSLMODE') + if ssl_val is None: + ssl_val = os.getenv('PGSSLMODE') - if ssl is None and have_tcp_addrs: - ssl = 'prefer' + if ssl_val is None and have_tcp_addrs: + ssl_val = 'prefer' - if isinstance(ssl, (str, SSLMode)): + if isinstance(ssl_val, (str, SSLMode)): try: - sslmode = SSLMode.parse(ssl) + sslmode = SSLMode.parse(ssl_val) except AttributeError: modes = ', '.join(m.name.replace('_', '-') for m in SSLMode) raise exceptions.ClientConfigurationError( @@ -501,23 +575,25 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, # docs at https://www.postgresql.org/docs/10/static/libpq-connect.html if sslmode < SSLMode.allow: - ssl = False + ssl_parsed = False else: - ssl = ssl_module.SSLContext(ssl_module.PROTOCOL_TLS_CLIENT) - ssl.check_hostname = sslmode >= SSLMode.verify_full + ssl_parsed = ssl_module.SSLContext(ssl_module.PROTOCOL_TLS_CLIENT) + ssl_parsed.check_hostname = sslmode >= SSLMode.verify_full if sslmode < SSLMode.require: - ssl.verify_mode = ssl_module.CERT_NONE + ssl_parsed.verify_mode = ssl_module.CERT_NONE else: if sslrootcert is None: sslrootcert = os.getenv('PGSSLROOTCERT') if sslrootcert: - ssl.load_verify_locations(cafile=sslrootcert) - ssl.verify_mode = ssl_module.CERT_REQUIRED + ssl_parsed.load_verify_locations(cafile=sslrootcert) + ssl_parsed.verify_mode = ssl_module.CERT_REQUIRED else: try: sslrootcert = _dot_postgresql_path('root.crt') if sslrootcert is not None: - ssl.load_verify_locations(cafile=sslrootcert) + ssl_parsed.load_verify_locations( + cafile=sslrootcert + ) else: raise exceptions.ClientConfigurationError( 'cannot determine location of user ' @@ -548,29 +624,31 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, detail=detail, ) elif sslmode == SSLMode.require: - ssl.verify_mode = ssl_module.CERT_NONE + ssl_parsed.verify_mode = ssl_module.CERT_NONE else: assert False, 'unreachable' else: - ssl.verify_mode = ssl_module.CERT_REQUIRED + ssl_parsed.verify_mode = ssl_module.CERT_REQUIRED if sslcrl is None: sslcrl = os.getenv('PGSSLCRL') if sslcrl: - ssl.load_verify_locations(cafile=sslcrl) - ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN + ssl_parsed.load_verify_locations(cafile=sslcrl) + ssl_parsed.verify_flags |= ( + ssl_module.VERIFY_CRL_CHECK_CHAIN + ) else: sslcrl = _dot_postgresql_path('root.crl') if sslcrl is not None: try: - ssl.load_verify_locations(cafile=sslcrl) + ssl_parsed.load_verify_locations(cafile=sslcrl) except ( FileNotFoundError, NotADirectoryError, ): pass else: - ssl.verify_flags |= \ + ssl_parsed.verify_flags |= \ ssl_module.VERIFY_CRL_CHECK_CHAIN if sslkey is None: @@ -584,14 +662,14 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, if sslcert is None: sslcert = os.getenv('PGSSLCERT') if sslcert: - ssl.load_cert_chain( + ssl_parsed.load_cert_chain( sslcert, keyfile=sslkey, password=lambda: sslpassword ) else: sslcert = _dot_postgresql_path('postgresql.crt') if sslcert is not None: try: - ssl.load_cert_chain( + ssl_parsed.load_cert_chain( sslcert, keyfile=sslkey, password=lambda: sslpassword @@ -603,28 +681,29 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, if hasattr(ssl, 'keylog_filename'): keylogfile = os.environ.get('SSLKEYLOGFILE') if keylogfile and not sys.flags.ignore_environment: - ssl.keylog_filename = keylogfile + ssl_parsed.keylog_filename = keylogfile if ssl_min_protocol_version is None: ssl_min_protocol_version = os.getenv('PGSSLMINPROTOCOLVERSION') if ssl_min_protocol_version: - ssl.minimum_version = _parse_tls_version( + ssl_parsed.minimum_version = _parse_tls_version( ssl_min_protocol_version ) else: - ssl.minimum_version = _parse_tls_version('TLSv1.2') + ssl_parsed.minimum_version = _parse_tls_version('TLSv1.2') if ssl_max_protocol_version is None: ssl_max_protocol_version = os.getenv('PGSSLMAXPROTOCOLVERSION') if ssl_max_protocol_version: - ssl.maximum_version = _parse_tls_version( + ssl_parsed.maximum_version = _parse_tls_version( ssl_max_protocol_version ) - elif ssl is True: - ssl = ssl_module.create_default_context() + elif ssl_val is True: + ssl_parsed = ssl_module.create_default_context() sslmode = SSLMode.verify_full else: + ssl_parsed = ssl_val sslmode = SSLMode.disable if server_settings is not None and ( @@ -635,23 +714,23 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, 'server_settings is expected to be None or ' 'a Dict[str, str]') - if target_session_attrs is None: - target_session_attrs = os.getenv( + if target_session_attrs_val is None: + target_session_attrs_val = os.getenv( "PGTARGETSESSIONATTRS", SessionAttribute.any ) try: - target_session_attrs = SessionAttribute(target_session_attrs) + target_session_attrs = SessionAttribute(target_session_attrs_val) except ValueError: raise exceptions.ClientConfigurationError( "target_session_attrs is expected to be one of " "{!r}" ", got {!r}".format( - SessionAttribute.__members__.values, target_session_attrs + SessionAttribute.__members__.values, target_session_attrs_val ) ) from None params = _ConnectionParameters( - user=user, password=password, database=database, ssl=ssl, + user=user, password=password, database=database, ssl=ssl_parsed, sslmode=sslmode, direct_tls=direct_tls, server_settings=server_settings, target_session_attrs=target_session_attrs) @@ -659,13 +738,26 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, return addrs, params -def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, - database, command_timeout, - statement_cache_size, - max_cached_statement_lifetime, - max_cacheable_statement_size, - ssl, direct_tls, server_settings, - target_session_attrs): +def _parse_connect_arguments( + *, + dsn: str | None, + host: HostType | None, + port: PortType | None, + user: str | None, + password: str | None, + passfile: str | None, + database: str | None, + command_timeout: float | typing.SupportsFloat | None, + statement_cache_size: int, + max_cached_statement_lifetime: int, + max_cacheable_statement_size: int, + ssl: SSLType | None, + direct_tls: bool, + server_settings: dict[str, str] | None, + target_session_attrs: SessionAttribute, +) -> tuple[ + list[tuple[str, int] | str], _ConnectionParameters, _ClientConfiguration +]: local_vars = locals() for var_name in {'max_cacheable_statement_size', 'max_cached_statement_lifetime', @@ -706,14 +798,27 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, class TLSUpgradeProto(asyncio.Protocol): - def __init__(self, loop, host, port, ssl_context, ssl_is_advisory): + on_data: asyncio.Future[bool] + host: str + port: int + ssl_context: ssl_module.SSLContext + ssl_is_advisory: bool | None + + def __init__( + self, + loop: asyncio.AbstractEventLoop | None, + host: str, + port: int, + ssl_context: ssl_module.SSLContext, + ssl_is_advisory: bool | None + ) -> None: self.on_data = _create_future(loop) self.host = host self.port = port self.ssl_context = ssl_context self.ssl_is_advisory = ssl_is_advisory - def data_received(self, data): + def data_received(self, data: bytes) -> None: if data == b'S': self.on_data.set_result(True) elif (self.ssl_is_advisory and @@ -731,20 +836,63 @@ def data_received(self, data): 'rejected SSL upgrade'.format( host=self.host, port=self.port))) - def connection_lost(self, exc): + def connection_lost(self, exc: Exception | None) -> None: if not self.on_data.done(): if exc is None: exc = ConnectionError('unexpected connection_lost() call') self.on_data.set_exception(exc) -async def _create_ssl_connection(protocol_factory, host, port, *, - loop, ssl_context, ssl_is_advisory=False): - - tr, pr = await loop.create_connection( - lambda: TLSUpgradeProto(loop, host, port, - ssl_context, ssl_is_advisory), - host, port) +@typing.overload +async def _create_ssl_connection( + protocol_factory: compat.Callable[[], _ProtocolT], + host: str, + port: int, + *, + loop: asyncio.AbstractEventLoop, + ssl_context: ssl_module.SSLContext, + ssl_is_advisory: bool | None = False +) -> _TPTupleType[_ProtocolT]: + ... + + +@typing.overload +async def _create_ssl_connection( + protocol_factory: compat.Callable[[], '_CancelProto'], + host: str, + port: int, + *, + loop: asyncio.AbstractEventLoop, + ssl_context: ssl_module.SSLContext, + ssl_is_advisory: bool | None = False +) -> _TPTupleType['_CancelProto']: + ... + + +async def _create_ssl_connection( + protocol_factory: compat.Callable[ + [], _ProtocolT + ] | compat.Callable[ + [], '_CancelProto' + ], + host: str, + port: int, + *, + loop: asyncio.AbstractEventLoop, + ssl_context: ssl_module.SSLContext, + ssl_is_advisory: typing.Optional[bool] = False +) -> _TPTupleType[typing.Any]: + + tr, pr = typing.cast( + compat.tuple[asyncio.WriteTransport, TLSUpgradeProto], + await loop.create_connection( + lambda: TLSUpgradeProto( + loop, host, port, ssl_context, ssl_is_advisory + ), + host, + port + ) + ) tr.write(struct.pack('!ll', 8, 80877103)) # SSLRequest message. @@ -757,8 +905,12 @@ async def _create_ssl_connection(protocol_factory, host, port, *, if hasattr(loop, 'start_tls'): if do_ssl_upgrade: try: - new_tr = await loop.start_tls( - tr, pr, ssl_context, server_hostname=host) + new_tr = typing.cast( + asyncio.WriteTransport, + await loop.start_tls( + tr, pr, ssl_context, server_hostname=host + ) + ) except (Exception, asyncio.CancelledError): tr.close() raise @@ -795,13 +947,13 @@ async def _create_ssl_connection(protocol_factory, host, port, *, async def _connect_addr( *, - addr, - loop, - params, - config, - connection_class, - record_class -): + addr: AddrType, + loop: asyncio.AbstractEventLoop, + params: _ConnectionParameters, + config: _ClientConfiguration, + connection_class: type[_ConnectionT], + record_class: type[_RecordT] +) -> _ConnectionT: assert loop is not None params_input = params @@ -810,7 +962,7 @@ async def _connect_addr( if inspect.isawaitable(password): password = await password - params = params._replace(password=password) + params = params._replace(password=typing.cast(str, password)) args = (addr, loop, config, connection_class, record_class, params_input) # prepare the params (which attempt has ssl) for the 2 attempts @@ -838,15 +990,15 @@ class _RetryConnectSignal(Exception): async def __connect_addr( - params, - retry, - addr, - loop, - config, - connection_class, - record_class, - params_input, -): + params: _ConnectionParameters, + retry: bool, + addr: AddrType, + loop: asyncio.AbstractEventLoop, + config: _ClientConfiguration, + connection_class: type[_ConnectionT], + record_class: type[_RecordT], + params_input: _ConnectionParameters, +) -> _ConnectionT: connected = _create_future(loop) proto_factory = lambda: protocol.Protocol( @@ -854,13 +1006,21 @@ async def __connect_addr( if isinstance(addr, str): # UNIX socket - connector = loop.create_unix_connection(proto_factory, addr) + connector = typing.cast( + compat.Coroutine[ + typing.Any, None, _TPTupleType['protocol.Protocol[_RecordT]'] + ], + loop.create_unix_connection(proto_factory, addr) + ) elif params.ssl and params.direct_tls: # if ssl and direct_tls are given, skip STARTTLS and perform direct # SSL connection - connector = loop.create_connection( - proto_factory, *addr, ssl=params.ssl + connector = typing.cast( + compat.Coroutine[ + typing.Any, None, _TPTupleType['protocol.Protocol[_RecordT]'] + ], + loop.create_connection(proto_factory, *addr, ssl=params.ssl) ) elif params.ssl: @@ -868,7 +1028,12 @@ async def __connect_addr( proto_factory, *addr, loop=loop, ssl_context=params.ssl, ssl_is_advisory=params.sslmode == SSLMode.prefer) else: - connector = loop.create_connection(proto_factory, *addr) + connector = typing.cast( + compat.Coroutine[ + typing.Any, None, _TPTupleType['protocol.Protocol[_RecordT]'] + ], + loop.create_connection(proto_factory, *addr) + ) tr, pr = await connector @@ -921,18 +1086,24 @@ class SessionAttribute(str, enum.Enum): read_only = "read-only" -def _accept_in_hot_standby(should_be_in_hot_standby: bool): +def _accept_in_hot_standby(should_be_in_hot_standby: bool) -> compat.Callable[ + [connection.Connection[typing.Any]], compat.Awaitable[bool] +]: """ If the server didn't report "in_hot_standby" at startup, we must determine the state by checking "SELECT pg_catalog.pg_is_in_recovery()". If the server allows a connection and states it is in recovery it must be a replica/standby server. """ - async def can_be_used(connection): + async def can_be_used( + connection: connection.Connection[typing.Any] + ) -> bool: settings = connection.get_settings() - hot_standby_status = getattr(settings, 'in_hot_standby', None) + hot_standby_status: str | None = getattr( + settings, 'in_hot_standby', None + ) if hot_standby_status is not None: - is_in_hot_standby = hot_standby_status == 'on' + is_in_hot_standby: bool = hot_standby_status == 'on' else: is_in_hot_standby = await connection.fetchval( "SELECT pg_catalog.pg_is_in_recovery()" @@ -942,11 +1113,15 @@ async def can_be_used(connection): return can_be_used -def _accept_read_only(should_be_read_only: bool): +def _accept_read_only(should_be_read_only: bool) -> compat.Callable[ + [connection.Connection[typing.Any]], compat.Awaitable[bool] +]: """ Verify the server has not set default_transaction_read_only=True """ - async def can_be_used(connection): + async def can_be_used( + connection: connection.Connection[typing.Any] + ) -> bool: settings = connection.get_settings() is_readonly = getattr(settings, 'default_transaction_read_only', 'off') @@ -957,11 +1132,19 @@ async def can_be_used(connection): return can_be_used -async def _accept_any(_): +async def _accept_any(_: connection.Connection[typing.Any]) -> bool: return True -target_attrs_check = { +target_attrs_check: typing.Final[ + dict[ + SessionAttribute, + compat.Callable[ + [connection.Connection[typing.Any]], + compat.Awaitable[bool] + ] + ] +] = { SessionAttribute.any: _accept_any, SessionAttribute.primary: _accept_in_hot_standby(False), SessionAttribute.standby: _accept_in_hot_standby(True), @@ -971,21 +1154,30 @@ async def _accept_any(_): } -async def _can_use_connection(connection, attr: SessionAttribute): +async def _can_use_connection( + connection: connection.Connection[typing.Any], + attr: SessionAttribute +) -> bool: can_use = target_attrs_check[attr] return await can_use(connection) -async def _connect(*, loop, connection_class, record_class, **kwargs): +async def _connect( + *, + loop: asyncio.AbstractEventLoop | None, + connection_class: type[_ConnectionT], + record_class: type[_RecordT], + **kwargs: typing.Any +) -> _ConnectionT: if loop is None: loop = asyncio.get_event_loop() addrs, params, config = _parse_connect_arguments(**kwargs) target_attr = params.target_session_attrs - candidates = [] + candidates: list[_ConnectionT] = [] chosen_connection = None - last_error = None + last_error: BaseException | None = None for addr in addrs: try: conn = await _connect_addr( @@ -1020,32 +1212,44 @@ async def _connect(*, loop, connection_class, record_class, **kwargs): ) -async def _cancel(*, loop, addr, params: _ConnectionParameters, - backend_pid, backend_secret): +class _CancelProto(asyncio.Protocol): - class CancelProto(asyncio.Protocol): + def __init__(self, loop: asyncio.AbstractEventLoop) -> None: + self.on_disconnect = _create_future(loop) + self.is_ssl = False - def __init__(self): - self.on_disconnect = _create_future(loop) - self.is_ssl = False + def connection_lost(self, exc: Exception | None) -> None: + if not self.on_disconnect.done(): + self.on_disconnect.set_result(True) - def connection_lost(self, exc): - if not self.on_disconnect.done(): - self.on_disconnect.set_result(True) + +async def _cancel( + *, + loop: asyncio.AbstractEventLoop, + addr: AddrType, + params: _ConnectionParameters, + backend_pid: int, + backend_secret: str +) -> None: + proto_factory: compat.Callable[ + [], _CancelProto + ] = lambda: _CancelProto(loop) if isinstance(addr, str): - tr, pr = await loop.create_unix_connection(CancelProto, addr) + tr, pr = typing.cast( + _TPTupleType[_CancelProto], + await loop.create_unix_connection(proto_factory, addr) + ) else: if params.ssl and params.sslmode != SSLMode.allow: tr, pr = await _create_ssl_connection( - CancelProto, + proto_factory, *addr, loop=loop, ssl_context=params.ssl, ssl_is_advisory=params.sslmode == SSLMode.prefer) else: - tr, pr = await loop.create_connection( - CancelProto, *addr) + tr, pr = await loop.create_connection(proto_factory, *addr) _set_nodelay(_get_socket(tr)) # Pack a CancelRequest message @@ -1058,7 +1262,7 @@ def connection_lost(self, exc): tr.close() -def _get_socket(transport): +def _get_socket(transport: asyncio.BaseTransport) -> typing.Any: sock = transport.get_extra_info('socket') if sock is None: # Shouldn't happen with any asyncio-complaint event loop. @@ -1067,14 +1271,16 @@ def _get_socket(transport): return sock -def _set_nodelay(sock): +def _set_nodelay(sock: typing.Any) -> None: if not hasattr(socket, 'AF_UNIX') or sock.family != socket.AF_UNIX: sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) -def _create_future(loop): +def _create_future( + loop: asyncio.AbstractEventLoop | None +) -> asyncio.Future[typing.Any]: try: - create_future = loop.create_future + create_future = loop.create_future # type: ignore[union-attr] except AttributeError: return asyncio.Future(loop=loop) else: diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 0367e365..d551e537 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -4,12 +4,14 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +from __future__ import annotations import asyncio import asyncpg import collections import collections.abc import contextlib +import dataclasses import functools import itertools import inspect @@ -32,15 +34,114 @@ from . import transaction from . import utils +if sys.version_info < (3, 10): + from typing_extensions import ParamSpec +else: + from typing import ParamSpec + +if typing.TYPE_CHECKING: + import io + + if sys.version_info < (3, 11): + from typing_extensions import Self + else: + from typing import Self + + from .protocol import protocol as _cprotocol + from .exceptions import _postgres_message + from . import pool_connection_proxy as _pool + from . import types + +_ConnectionT = typing.TypeVar('_ConnectionT', bound='Connection[typing.Any]') +_RecordT = typing.TypeVar('_RecordT', bound=protocol.Record) +_OtherRecordT = typing.TypeVar('_OtherRecordT', bound=protocol.Record) +_P = ParamSpec('_P') + +_WriterType = compat.Callable[ + [bytes], compat.Coroutine[typing.Any, typing.Any, None] +] +_OutputType = typing.Union[ + 'os.PathLike[typing.Any]', typing.BinaryIO, _WriterType +] +_CopyFormat = typing.Literal['text', 'csv', 'binary'] +_SourceType = typing.Union[ + 'os.PathLike[typing.Any]', typing.BinaryIO, compat.AsyncIterable[bytes] +] +_RecordsType = compat.list[_RecordT] +_RecordsTupleType = compat.tuple[_RecordsType[_RecordT], bytes, bool] + + +class Listener(typing.Protocol): + def __call__( + self, + con_ref: Connection[ + typing.Any + ] | _pool.PoolConnectionProxy[typing.Any], + pid: int, + channel: str, + payload: object, + /, + ) -> compat.Coroutine[typing.Any, typing.Any, None] | None: + ... + + +class LogListener(typing.Protocol): + def __call__( + self, + con_ref: Connection[ + typing.Any + ] | _pool.PoolConnectionProxy[typing.Any], + message: _postgres_message.PostgresMessage, + /, + ) -> compat.Coroutine[typing.Any, typing.Any, None] | None: + ... + + +class TerminationListener(typing.Protocol): + def __call__( + self, + con_ref: Connection[ + typing.Any + ] | _pool.PoolConnectionProxy[typing.Any], + /, + ) -> compat.Coroutine[typing.Any, typing.Any, None] | None: + ... + + +class QueryLogger(typing.Protocol): + def __call__( + self, record: LoggedQuery, / + ) -> compat.Coroutine[typing.Any, typing.Any, None] | None: + ... + + +class Executor(typing.Protocol[_RecordT]): + def __call__( + self, + statement: _cprotocol.PreparedStatementState[_RecordT], + timeout: float | None, + / + ) -> typing.Any: + ... + + +class OnRemove(typing.Protocol[_RecordT]): + def __call__( + self, + statement: _cprotocol.PreparedStatementState[_RecordT], + / + ) -> None: + ... + class ConnectionMeta(type): - def __instancecheck__(cls, instance): + def __instancecheck__(cls, instance: object) -> bool: mro = type(instance).__mro__ return Connection in mro or _ConnectionProxy in mro -class Connection(metaclass=ConnectionMeta): +class Connection(typing.Generic[_RecordT], metaclass=ConnectionMeta): """A representation of a database session. Connections are created by calling :func:`~asyncpg.connection.connect`. @@ -56,10 +157,66 @@ class Connection(metaclass=ConnectionMeta): '_log_listeners', '_termination_listeners', '_cancellations', '_source_traceback', '_query_loggers', '__weakref__') - def __init__(self, protocol, transport, loop, - addr, - config: connect_utils._ClientConfiguration, - params: connect_utils._ConnectionParameters): + _protocol: _cprotocol.BaseProtocol[_RecordT] + _transport: object + _loop: asyncio.AbstractEventLoop + _top_xact: transaction.Transaction | None + _aborted: bool + _pool_release_ctr: int + _stmt_cache: _StatementCache + _stmts_to_close: set[_cprotocol.PreparedStatementState[typing.Any]] + _stmt_cache_enabled: bool + _listeners: dict[ + str, + set[ + _Callback[ + [ + Connection[typing.Any] | + _pool.PoolConnectionProxy[typing.Any], + int, + str, + object + ] + ] + ] + ] + _server_version: types.ServerVersion + _server_caps: ServerCapabilities + _intro_query: str + _reset_query: str | None + _proxy: _pool.PoolConnectionProxy[typing.Any] | None + _stmt_exclusive_section: _Atomic + _config: connect_utils._ClientConfiguration + _params: connect_utils._ConnectionParameters + _addr: connect_utils.AddrType + _log_listeners: set[ + _Callback[ + [ + Connection[typing.Any] | _pool.PoolConnectionProxy[typing.Any], + _postgres_message.PostgresMessage, + ] + ] + ] + _termination_listeners: set[ + _Callback[ + [ + Connection[typing.Any] | _pool.PoolConnectionProxy[typing.Any], + ] + ] + ] + _cancellations: set[asyncio.Task[typing.Any]] + _source_traceback: str | None + _query_loggers: set[_Callback[[LoggedQuery]]] + + def __init__( + self, + protocol: _cprotocol.BaseProtocol[_RecordT], + transport: object, + loop: asyncio.AbstractEventLoop, + addr: tuple[str, int] | str, + config: connect_utils._ClientConfiguration, + params: connect_utils._ConnectionParameters, + ) -> None: self._protocol = protocol self._transport = transport self._loop = loop @@ -120,7 +277,7 @@ def __init__(self, protocol, transport, loop, else: self._source_traceback = None - def __del__(self): + def __del__(self) -> None: if not self.is_closed() and self._protocol is not None: if self._source_traceback: msg = "unclosed connection {!r}; created at:\n {}".format( @@ -136,7 +293,7 @@ def __del__(self): if not self._loop.is_closed(): self.terminate() - async def add_listener(self, channel, callback): + async def add_listener(self, channel: str, callback: Listener) -> None: """Add a listener for Postgres notifications. :param str channel: Channel to listen on. @@ -158,7 +315,7 @@ async def add_listener(self, channel, callback): self._listeners[channel] = set() self._listeners[channel].add(_Callback.from_callable(callback)) - async def remove_listener(self, channel, callback): + async def remove_listener(self, channel: str, callback: Listener) -> None: """Remove a listening callback on the specified channel.""" if self.is_closed(): return @@ -172,7 +329,7 @@ async def remove_listener(self, channel, callback): del self._listeners[channel] await self.fetch('UNLISTEN {}'.format(utils._quote_ident(channel))) - def add_log_listener(self, callback): + def add_log_listener(self, callback: LogListener) -> None: """Add a listener for Postgres log messages. It will be called when asyncronous NoticeResponse is received @@ -194,14 +351,14 @@ def add_log_listener(self, callback): raise exceptions.InterfaceError('connection is closed') self._log_listeners.add(_Callback.from_callable(callback)) - def remove_log_listener(self, callback): + def remove_log_listener(self, callback: LogListener) -> None: """Remove a listening callback for log messages. .. versionadded:: 0.12.0 """ self._log_listeners.discard(_Callback.from_callable(callback)) - def add_termination_listener(self, callback): + def add_termination_listener(self, callback: TerminationListener) -> None: """Add a listener that will be called when the connection is closed. :param callable callback: @@ -215,7 +372,9 @@ def add_termination_listener(self, callback): """ self._termination_listeners.add(_Callback.from_callable(callback)) - def remove_termination_listener(self, callback): + def remove_termination_listener( + self, callback: TerminationListener + ) -> None: """Remove a listening callback for connection termination. :param callable callback: @@ -226,7 +385,7 @@ def remove_termination_listener(self, callback): """ self._termination_listeners.discard(_Callback.from_callable(callback)) - def add_query_logger(self, callback): + def add_query_logger(self, callback: QueryLogger) -> None: """Add a logger that will be called when queries are executed. :param callable callback: @@ -239,7 +398,7 @@ def add_query_logger(self, callback): """ self._query_loggers.add(_Callback.from_callable(callback)) - def remove_query_logger(self, callback): + def remove_query_logger(self, callback: QueryLogger) -> None: """Remove a query logger callback. :param callable callback: @@ -250,11 +409,11 @@ def remove_query_logger(self, callback): """ self._query_loggers.discard(_Callback.from_callable(callback)) - def get_server_pid(self): + def get_server_pid(self) -> int: """Return the PID of the Postgres server the connection is bound to.""" return self._protocol.get_server_pid() - def get_server_version(self): + def get_server_version(self) -> types.ServerVersion: """Return the version of the connected PostgreSQL server. The returned value is a named tuple similar to that in @@ -270,15 +429,20 @@ def get_server_version(self): """ return self._server_version - def get_settings(self): + def get_settings(self) -> _cprotocol.ConnectionSettings: """Return connection settings. :return: :class:`~asyncpg.ConnectionSettings`. """ return self._protocol.get_settings() - def transaction(self, *, isolation=None, readonly=False, - deferrable=False): + def transaction( + self, + *, + isolation: transaction.IsolationLevels | None = None, + readonly: bool = False, + deferrable: bool = False, + ) -> transaction.Transaction: """Create a :class:`~transaction.Transaction` object. Refer to `PostgreSQL documentation`_ on the meaning of transaction @@ -303,7 +467,7 @@ def transaction(self, *, isolation=None, readonly=False, self._check_open() return transaction.Transaction(self, isolation, readonly, deferrable) - def is_in_transaction(self): + def is_in_transaction(self) -> bool: """Return True if Connection is currently inside a transaction. :return bool: True if inside transaction, False otherwise. @@ -312,7 +476,9 @@ def is_in_transaction(self): """ return self._protocol.is_in_transaction() - async def execute(self, query: str, *args, timeout: float=None) -> str: + async def execute( + self, query: str, *args: object, timeout: float | None = None + ) -> str: """Execute an SQL command (or commands). This method can execute many SQL commands at once, when no arguments @@ -359,7 +525,13 @@ async def execute(self, query: str, *args, timeout: float=None) -> str: ) return status.decode() - async def executemany(self, command: str, args, *, timeout: float=None): + async def executemany( + self, + command: str, + args: compat.Iterable[compat.Sequence[object]], + *, + timeout: float | None = None, + ) -> None: """Execute an SQL *command* for each sequence of arguments in *args*. Example: @@ -390,16 +562,42 @@ async def executemany(self, command: str, args, *, timeout: float=None): self._check_open() return await self._executemany(command, args, timeout) + @typing.overload async def _get_statement( self, - query, - timeout, + query: str, + timeout: float | None, *, - named=False, - use_cache=True, - ignore_custom_codec=False, - record_class=None - ): + named: bool | str = ..., + use_cache: bool = ..., + ignore_custom_codec: bool = ..., + record_class: None = ..., + ) -> _cprotocol.PreparedStatementState[_RecordT]: + ... + + @typing.overload + async def _get_statement( + self, + query: str, + timeout: float | None, + *, + named: bool | str = ..., + use_cache: bool = ..., + ignore_custom_codec: bool = ..., + record_class: type[_OtherRecordT], + ) -> _cprotocol.PreparedStatementState[_OtherRecordT]: + ... + + async def _get_statement( + self, + query: str, + timeout: float | None, + *, + named: bool | str = False, + use_cache: bool = True, + ignore_custom_codec: bool = False, + record_class: type[typing.Any] | None = None + ) -> _cprotocol.PreparedStatementState[typing.Any]: if record_class is None: record_class = self._protocol.get_record_class() else: @@ -492,7 +690,11 @@ async def _get_statement( return statement - async def _introspect_types(self, typeoids, timeout): + async def _introspect_types( + self, + typeoids: compat.Iterable[int], + timeout: float | None + ) -> tuple[typing.Any, _cprotocol.PreparedStatementState[_RecordT]]: if self._server_caps.jit: try: cfgrow, _ = await self.__execute( @@ -534,7 +736,7 @@ async def _introspect_types(self, typeoids, timeout): return result - async def _introspect_type(self, typename, schema): + async def _introspect_type(self, typename: str, schema: str) -> typing.Any: if ( schema == 'pg_catalog' and typename.lower() in protocol.BUILTIN_TYPE_NAME_MAP @@ -562,14 +764,47 @@ async def _introspect_type(self, typename, schema): return rows[0] + @typing.overload def cursor( self, - query, - *args, - prefetch=None, - timeout=None, - record_class=None - ): + query: str, + *args: object, + prefetch: int | None = ..., + timeout: float | None = ..., + record_class: None = ..., + ) -> cursor.CursorFactory[_RecordT]: + ... + + @typing.overload + def cursor( + self, + query: str, + *args: object, + prefetch: int | None = ..., + timeout: float | None = ..., + record_class: type[_OtherRecordT], + ) -> cursor.CursorFactory[_OtherRecordT]: + ... + + @typing.overload + def cursor( + self, + query: str, + *args: object, + prefetch: int | None = ..., + timeout: float | None = ..., + record_class: type[_OtherRecordT] | None, + ) -> cursor.CursorFactory[_RecordT] | cursor.CursorFactory[_OtherRecordT]: + ... + + def cursor( + self, + query: str, + *args: object, + prefetch: int | None = None, + timeout: float | None = None, + record_class: type[_OtherRecordT] | None = None, + ) -> cursor.CursorFactory[typing.Any]: """Return a *cursor factory* for the specified query. :param args: @@ -601,13 +836,52 @@ def cursor( record_class, ) + @typing.overload async def prepare( self, - query, + query: str, *, - name=None, - timeout=None, - record_class=None, + name: str | None = ..., + timeout: float | None = ..., + record_class: None = ..., + ) -> prepared_stmt.PreparedStatement[_RecordT]: + ... + + @typing.overload + async def prepare( + self, + query: str, + *, + name: str | None = ..., + timeout: float | None = ..., + record_class: type[_OtherRecordT], + ) -> prepared_stmt.PreparedStatement[_OtherRecordT]: + ... + + @typing.overload + async def prepare( + self, + query: str, + *, + name: str | None = ..., + timeout: float | None = ..., + record_class: type[_OtherRecordT] | None, + ) -> ( + prepared_stmt.PreparedStatement[_RecordT] + | prepared_stmt.PreparedStatement[_OtherRecordT] + ): + ... + + async def prepare( + self, + query: str, + *, + name: str | None = None, + timeout: float | None = None, + record_class: type[_OtherRecordT] | None = None, + ) -> ( + prepared_stmt.PreparedStatement[_RecordT] + | prepared_stmt.PreparedStatement[_OtherRecordT] ): """Create a *prepared statement* for the specified query. @@ -641,32 +915,108 @@ async def prepare( record_class=record_class, ) + @typing.overload + async def _prepare( + self, + query: str, + *, + name: str | None = ..., + timeout: float | None = ..., + use_cache: bool = ..., + record_class: None = ..., + ) -> prepared_stmt.PreparedStatement[_RecordT]: + ... + + @typing.overload + async def _prepare( + self, + query: str, + *, + name: str | None = ..., + timeout: float | None = ..., + use_cache: bool = ..., + record_class: type[_OtherRecordT], + ) -> prepared_stmt.PreparedStatement[_OtherRecordT]: + ... + + @typing.overload async def _prepare( self, - query, + query: str, *, - name=None, - timeout=None, - use_cache: bool=False, - record_class=None + name: str | None = ..., + timeout: float | None = ..., + use_cache: bool = ..., + record_class: type[_OtherRecordT] | None, + ) -> ( + prepared_stmt.PreparedStatement[_RecordT] + | prepared_stmt.PreparedStatement[_OtherRecordT] + ): + ... + + async def _prepare( + self, + query: str, + *, + name: str | None = None, + timeout: float | None = None, + use_cache: bool = False, + record_class: type[_OtherRecordT] | None = None + ) -> ( + prepared_stmt.PreparedStatement[_RecordT] + | prepared_stmt.PreparedStatement[_OtherRecordT] ): self._check_open() + + named: bool | str = True if name is None else name stmt = await self._get_statement( query, timeout, - named=True if name is None else name, + named=named, use_cache=use_cache, record_class=record_class, ) - return prepared_stmt.PreparedStatement(self, query, stmt) + return prepared_stmt.PreparedStatement(self, query, typing.cast( + '_cprotocol.PreparedStatementState[typing.Any]', stmt + )) + + @typing.overload + async def fetch( + self, + query: str, + *args: object, + timeout: float | None = ..., + record_class: None = ..., + ) -> list[_RecordT]: + ... + @typing.overload async def fetch( self, - query, - *args, - timeout=None, - record_class=None - ) -> list: + query: str, + *args: object, + timeout: float | None = ..., + record_class: type[_OtherRecordT], + ) -> list[_OtherRecordT]: + ... + + @typing.overload + async def fetch( + self, + query: str, + *args: object, + timeout: float | None = ..., + record_class: type[_OtherRecordT] | None, + ) -> list[_RecordT] | list[_OtherRecordT]: + ... + + async def fetch( + self, + query: str, + *args: object, + timeout: float | None = None, + record_class: type[_OtherRecordT] | None = None + ) -> list[_RecordT] | list[_OtherRecordT]: """Run a query and return the results as a list of :class:`Record`. :param str query: @@ -696,7 +1046,13 @@ async def fetch( record_class=record_class, ) - async def fetchval(self, query, *args, column=0, timeout=None): + async def fetchval( + self, + query: str, + *args: object, + column: int = 0, + timeout: float | None = None, + ) -> typing.Any: """Run a query and return a value in the first row. :param str query: Query text. @@ -717,13 +1073,43 @@ async def fetchval(self, query, *args, column=0, timeout=None): return None return data[0][column] + @typing.overload async def fetchrow( self, - query, - *args, - timeout=None, - record_class=None - ): + query: str, + *args: object, + timeout: float | None = ..., + record_class: None = ..., + ) -> _RecordT | None: + ... + + @typing.overload + async def fetchrow( + self, + query: str, + *args: object, + timeout: float | None = ..., + record_class: type[_OtherRecordT], + ) -> _OtherRecordT | None: + ... + + @typing.overload + async def fetchrow( + self, + query: str, + *args: object, + timeout: float | None = ..., + record_class: type[_OtherRecordT] | None, + ) -> _RecordT | _OtherRecordT | None: + ... + + async def fetchrow( + self, + query: str, + *args: object, + timeout: float | None = None, + record_class: type[_OtherRecordT] | None = None + ) -> _RecordT | _OtherRecordT | None: """Run a query and return the first row. :param str query: @@ -757,11 +1143,24 @@ async def fetchrow( return None return data[0] - async def copy_from_table(self, table_name, *, output, - columns=None, schema_name=None, timeout=None, - format=None, oids=None, delimiter=None, - null=None, header=None, quote=None, - escape=None, force_quote=None, encoding=None): + async def copy_from_table( + self, + table_name: str, + *, + output: _OutputType, + columns: compat.Iterable[str] | None = None, + schema_name: str | None = None, + timeout: float | None = None, + format: _CopyFormat | None = None, + oids: int | None = None, + delimiter: str | None = None, + null: str | None = None, + header: bool | None = None, + quote: str | None = None, + escape: str | None = None, + force_quote: bool | compat.Iterable[str] | None = None, + encoding: str | None = None, + ) -> str: """Copy table contents to a file or file-like object. :param str table_name: @@ -829,11 +1228,22 @@ async def copy_from_table(self, table_name, *, output, return await self._copy_out(copy_stmt, output, timeout) - async def copy_from_query(self, query, *args, output, - timeout=None, format=None, oids=None, - delimiter=None, null=None, header=None, - quote=None, escape=None, force_quote=None, - encoding=None): + async def copy_from_query( + self, + query: str, + *args: object, + output: _OutputType, + timeout: float | None = None, + format: _CopyFormat | None = None, + oids: int | None = None, + delimiter: str | None = None, + null: str | None = None, + header: bool | None = None, + quote: str | None = None, + escape: str | None = None, + force_quote: bool | compat.Iterable[str] | None = None, + encoding: str | None = None, + ) -> str: """Copy the results of a query to a file or file-like object. :param str query: @@ -891,13 +1301,28 @@ async def copy_from_query(self, query, *args, output, return await self._copy_out(copy_stmt, output, timeout) - async def copy_to_table(self, table_name, *, source, - columns=None, schema_name=None, timeout=None, - format=None, oids=None, freeze=None, - delimiter=None, null=None, header=None, - quote=None, escape=None, force_quote=None, - force_not_null=None, force_null=None, - encoding=None, where=None): + async def copy_to_table( + self, + table_name: str, + *, + source: _SourceType, + columns: compat.Iterable[str] | None = None, + schema_name: str | None = None, + timeout: float | None = None, + format: _CopyFormat | None = None, + oids: int | None = None, + freeze: bool | None = None, + delimiter: str | None = None, + null: str | None = None, + header: bool | None = None, + quote: str | None = None, + escape: str | None = None, + force_quote: bool | compat.Iterable[str] | None = None, + force_not_null: bool | compat.Iterable[str] | None = None, + force_null: bool | compat.Iterable[str] | None = None, + encoding: str | None = None, + where: str | None = None, + ) -> str: """Copy data to the specified table. :param str table_name: @@ -979,9 +1404,18 @@ async def copy_to_table(self, table_name, *, source, return await self._copy_in(copy_stmt, source, timeout) - async def copy_records_to_table(self, table_name, *, records, - columns=None, schema_name=None, - timeout=None, where=None): + async def copy_records_to_table( + self, + table_name: str, + *, + records: compat.Iterable[ + compat.Sequence[object] + ] | compat.AsyncIterable[compat.Sequence[object]], + columns: compat.Iterable[str] | None = None, + schema_name: str | None = None, + timeout: float | None = None, + where: str | None = None, + ) -> str: """Copy a list of records to the specified table using binary COPY. :param str table_name: @@ -1081,7 +1515,7 @@ async def copy_records_to_table(self, table_name, *, records, return await self._protocol.copy_in( copy_stmt, None, None, records, intro_ps._state, timeout) - def _format_copy_where(self, where): + def _format_copy_where(self, where: str | None) -> str: if where and not self._server_caps.sql_copy_from_where: raise exceptions.UnsupportedServerFeatureError( 'the `where` parameter requires PostgreSQL 12 or later') @@ -1093,13 +1527,25 @@ def _format_copy_where(self, where): return where_clause - def _format_copy_opts(self, *, format=None, oids=None, freeze=None, - delimiter=None, null=None, header=None, quote=None, - escape=None, force_quote=None, force_not_null=None, - force_null=None, encoding=None): + def _format_copy_opts( + self, + *, + format: _CopyFormat | None = None, + oids: int | None = None, + freeze: bool | None = None, + delimiter: str | None = None, + null: str | None = None, + header: bool | None = None, + quote: str | None = None, + escape: str | None = None, + force_quote: bool | compat.Iterable[str] | None = None, + force_not_null: bool | compat.Iterable[str] | None = None, + force_null: bool | compat.Iterable[str] | None = None, + encoding: str | None = None + ) -> str: kwargs = dict(locals()) kwargs.pop('self') - opts = [] + opts: list[str] = [] if force_quote is not None and isinstance(force_quote, bool): kwargs.pop('force_quote') @@ -1122,24 +1568,31 @@ def _format_copy_opts(self, *, format=None, oids=None, freeze=None, else: return '' - async def _copy_out(self, copy_stmt, output, timeout): + async def _copy_out( + self, copy_stmt: str, output: _OutputType, timeout: float | None + ) -> str: try: - path = os.fspath(output) + path: str | bytes | None = typing.cast( + 'str | bytes', os.fspath(typing.cast(typing.Any, output)) + ) except TypeError: # output is not a path-like object path = None - writer = None + writer: _WriterType | None = None opened_by_us = False run_in_executor = self._loop.run_in_executor if path is not None: # a path - f = await run_in_executor(None, open, path, 'wb') + f = typing.cast( + 'io.BufferedWriter', + await run_in_executor(None, open, path, 'wb') + ) opened_by_us = True elif hasattr(output, 'write'): # file-like - f = output + f = typing.cast('io.BufferedWriter', output) elif callable(output): # assuming calling output returns an awaitable. writer = output @@ -1151,7 +1604,7 @@ async def _copy_out(self, copy_stmt, output, timeout): ) if writer is None: - async def _writer(data): + async def _writer(data: bytes) -> None: await run_in_executor(None, f.write, data) writer = _writer @@ -1161,14 +1614,18 @@ async def _writer(data): if opened_by_us: f.close() - async def _copy_in(self, copy_stmt, source, timeout): + async def _copy_in( + self, copy_stmt: str, source: _SourceType, timeout: float | None + ) -> str: try: - path = os.fspath(source) + path: str | bytes | None = typing.cast( + 'str | bytes', os.fspath(typing.cast(typing.Any, source)) + ) except TypeError: # source is not a path-like object path = None - f = None + f: typing.BinaryIO | None = None reader = None data = None opened_by_us = False @@ -1176,11 +1633,14 @@ async def _copy_in(self, copy_stmt, source, timeout): if path is not None: # a path - f = await run_in_executor(None, open, path, 'rb') + f = typing.cast( + 'io.BufferedWriter', + await run_in_executor(None, open, path, 'rb') + ) opened_by_us = True elif hasattr(source, 'read'): # file-like - f = source + f = typing.cast('io.BufferedWriter', source) elif isinstance(source, collections.abc.AsyncIterable): # assuming calling output returns an awaitable. # copy_in() is designed to handle very large amounts of data, and @@ -1194,11 +1654,13 @@ async def _copy_in(self, copy_stmt, source, timeout): if f is not None: # Copying from a file-like object. class _Reader: - def __aiter__(self): + def __aiter__(self) -> Self: return self - async def __anext__(self): - data = await run_in_executor(None, f.read, 524288) + async def __anext__(self) -> bytes: + data = await run_in_executor( + None, typing.cast(typing.BinaryIO, f).read, 524288 + ) if len(data) == 0: raise StopAsyncIteration else: @@ -1211,11 +1673,20 @@ async def __anext__(self): copy_stmt, reader, data, None, None, timeout) finally: if opened_by_us: - await run_in_executor(None, f.close) + await run_in_executor( + None, + typing.cast(typing.BinaryIO, f).close + ) - async def set_type_codec(self, typename, *, - schema='public', encoder, decoder, - format='text'): + async def set_type_codec( + self, + typename: str, + *, + schema: str = 'public', + encoder: compat.Callable[[typing.Any], typing.Any], + decoder: compat.Callable[[typing.Any], typing.Any], + format: str = 'text', + ) -> None: """Set an encoder/decoder pair for the specified data type. :param typename: @@ -1337,7 +1808,7 @@ async def set_type_codec(self, typename, *, self._check_open() settings = self._protocol.get_settings() typeinfo = await self._introspect_type(typename, schema) - full_typeinfos = [] + full_typeinfos: list[object] = [] if introspection.is_scalar_type(typeinfo): kind = 'scalar' elif introspection.is_composite_type(typeinfo): @@ -1375,7 +1846,9 @@ async def set_type_codec(self, typename, *, # Statement cache is no longer valid due to codec changes. self._drop_local_statement_cache() - async def reset_type_codec(self, typename, *, schema='public'): + async def reset_type_codec( + self, typename: str, *, schema: str = 'public' + ) -> None: """Reset *typename* codec to the default implementation. :param typename: @@ -1395,9 +1868,14 @@ async def reset_type_codec(self, typename, *, schema='public'): # Statement cache is no longer valid due to codec changes. self._drop_local_statement_cache() - async def set_builtin_type_codec(self, typename, *, - schema='public', codec_name, - format=None): + async def set_builtin_type_codec( + self, + typename: str, + *, + schema: str = 'public', + codec_name: str, + format: str | None = None, + ) -> None: """Set a builtin codec for the specified scalar data type. This method has two uses. The first is to register a builtin @@ -1445,7 +1923,7 @@ async def set_builtin_type_codec(self, typename, *, # Statement cache is no longer valid due to codec changes. self._drop_local_statement_cache() - def is_closed(self): + def is_closed(self) -> bool: """Return ``True`` if the connection is closed, ``False`` otherwise. :return bool: ``True`` if the connection is closed, ``False`` @@ -1453,7 +1931,7 @@ def is_closed(self): """ return self._aborted or not self._protocol.is_connected() - async def close(self, *, timeout=None): + async def close(self, *, timeout: float | None = None) -> None: """Close the connection gracefully. :param float timeout: @@ -1472,13 +1950,13 @@ async def close(self, *, timeout=None): finally: self._cleanup() - def terminate(self): + def terminate(self) -> None: """Terminate the connection without waiting for pending data.""" if not self.is_closed(): self._abort() self._cleanup() - async def reset(self, *, timeout=None): + async def reset(self, *, timeout: float | None = None) -> None: self._check_open() self._listeners.clear() self._log_listeners.clear() @@ -1499,13 +1977,13 @@ async def reset(self, *, timeout=None): if reset_query: await self.execute(reset_query, timeout=timeout) - def _abort(self): + def _abort(self) -> None: # Put the connection into the aborted state. self._aborted = True self._protocol.abort() - self._protocol = None + self._protocol = None # type: ignore[assignment] - def _cleanup(self): + def _cleanup(self) -> None: self._call_termination_listeners() # Free the resources associated with this connection. # This must be called when a connection is terminated. @@ -1521,7 +1999,7 @@ def _cleanup(self): self._query_loggers.clear() self._clean_tasks() - def _clean_tasks(self): + def _clean_tasks(self) -> None: # Wrap-up any remaining tasks associated with this connection. if self._cancellations: for fut in self._cancellations: @@ -1529,16 +2007,16 @@ def _clean_tasks(self): fut.cancel() self._cancellations.clear() - def _check_open(self): + def _check_open(self) -> None: if self.is_closed(): raise exceptions.InterfaceError('connection is closed') - def _get_unique_id(self, prefix): + def _get_unique_id(self, prefix: str) -> str: global _uid _uid += 1 return '__asyncpg_{}_{:x}__'.format(prefix, _uid) - def _mark_stmts_as_closed(self): + def _mark_stmts_as_closed(self) -> None: for stmt in self._stmt_cache.iter_statements(): stmt.mark_closed() @@ -1548,7 +2026,9 @@ def _mark_stmts_as_closed(self): self._stmt_cache.clear() self._stmts_to_close.clear() - def _maybe_gc_stmt(self, stmt): + def _maybe_gc_stmt( + self, stmt: _cprotocol.PreparedStatementState[typing.Any] + ) -> None: if ( stmt.refs == 0 and stmt.name @@ -1567,7 +2047,7 @@ def _maybe_gc_stmt(self, stmt): stmt.mark_closed() self._stmts_to_close.add(stmt) - async def _cleanup_stmts(self): + async def _cleanup_stmts(self) -> None: # Called whenever we create a new prepared statement in # `Connection._get_statement()` and `_stmts_to_close` is # not empty. @@ -1578,7 +2058,7 @@ async def _cleanup_stmts(self): # so we ignore the timeout. await self._protocol.close_statement(stmt, protocol.NO_TIMEOUT) - async def _cancel(self, waiter): + async def _cancel(self, waiter: asyncio.Future[None]) -> None: try: # Open new connection to the server await connect_utils._cancel( @@ -1602,15 +2082,18 @@ async def _cancel(self, waiter): if not waiter.done(): waiter.set_exception(ex) finally: - self._cancellations.discard( - asyncio.current_task(self._loop)) + current_task = asyncio.current_task(self._loop) + if current_task is not None: + self._cancellations.discard(current_task) if not waiter.done(): waiter.set_result(None) - def _cancel_current_command(self, waiter): + def _cancel_current_command(self, waiter: asyncio.Future[None]) -> None: self._cancellations.add(self._loop.create_task(self._cancel(waiter))) - def _process_log_message(self, fields, last_query): + def _process_log_message( + self, fields: dict[str, str], last_query: str + ) -> None: if not self._log_listeners: return @@ -1618,38 +2101,31 @@ def _process_log_message(self, fields, last_query): con_ref = self._unwrap() for cb in self._log_listeners: - if cb.is_async: - self._loop.create_task(cb.cb(con_ref, message)) - else: - self._loop.call_soon(cb.cb, con_ref, message) + cb.invoke(self._loop, con_ref, message) - def _call_termination_listeners(self): + def _call_termination_listeners(self) -> None: if not self._termination_listeners: return con_ref = self._unwrap() for cb in self._termination_listeners: - if cb.is_async: - self._loop.create_task(cb.cb(con_ref)) - else: - self._loop.call_soon(cb.cb, con_ref) + cb.invoke(self._loop, con_ref) self._termination_listeners.clear() - def _process_notification(self, pid, channel, payload): + def _process_notification( + self, pid: int, channel: str, payload: typing.Any + ) -> None: if channel not in self._listeners: return con_ref = self._unwrap() for cb in self._listeners[channel]: - if cb.is_async: - self._loop.create_task(cb.cb(con_ref, pid, channel, payload)) - else: - self._loop.call_soon(cb.cb, con_ref, pid, channel, payload) + cb.invoke(self._loop, con_ref, pid, channel, payload) - def _unwrap(self): + def _unwrap(self) -> Self | _pool.PoolConnectionProxy[typing.Any]: if self._proxy is None: - con_ref = self + con_ref: Self | _pool.PoolConnectionProxy[typing.Any] = self else: # `_proxy` is not None when the connection is a member # of a connection pool. Which means that the user is working @@ -1658,13 +2134,13 @@ def _unwrap(self): con_ref = self._proxy return con_ref - def _get_reset_query(self): + def _get_reset_query(self) -> str: if self._reset_query is not None: return self._reset_query caps = self._server_caps - _reset_query = [] + _reset_query: list[str] = [] if caps.advisory_locks: _reset_query.append('SELECT pg_advisory_unlock_all();') if caps.sql_close_all: @@ -1674,12 +2150,11 @@ def _get_reset_query(self): if caps.sql_reset: _reset_query.append('RESET ALL;') - _reset_query = '\n'.join(_reset_query) - self._reset_query = _reset_query + self._reset_query = '\n'.join(_reset_query) - return _reset_query + return self._reset_query - def _set_proxy(self, proxy): + def _set_proxy(self, proxy: _pool.PoolConnectionProxy[typing.Any]) -> None: if self._proxy is not None and proxy is not None: # Should not happen unless there is a bug in `Pool`. raise exceptions.InterfaceError( @@ -1687,7 +2162,9 @@ def _set_proxy(self, proxy): self._proxy = proxy - def _check_listeners(self, listeners, listener_type): + def _check_listeners( + self, listeners: compat.Sized, listener_type: str + ) -> None: if listeners: count = len(listeners) @@ -1699,7 +2176,7 @@ def _check_listeners(self, listeners, listener_type): warnings.warn(w) - def _on_release(self, stacklevel=1): + def _on_release(self, stacklevel: int = 1) -> None: # Invalidate external references to the connection. self._pool_release_ctr += 1 # Called when the connection is about to be released to the pool. @@ -1710,10 +2187,10 @@ def _on_release(self, stacklevel=1): self._check_listeners( self._log_listeners, 'log') - def _drop_local_statement_cache(self): + def _drop_local_statement_cache(self) -> None: self._stmt_cache.clear() - def _drop_global_statement_cache(self): + def _drop_global_statement_cache(self) -> None: if self._proxy is not None: # This connection is a member of a pool, so we delegate # the cache drop to the pool. @@ -1722,10 +2199,10 @@ def _drop_global_statement_cache(self): else: self._drop_local_statement_cache() - def _drop_local_type_cache(self): + def _drop_local_type_cache(self) -> None: self._protocol.get_settings().clear_type_cache() - def _drop_global_type_cache(self): + def _drop_global_type_cache(self) -> None: if self._proxy is not None: # This connection is a member of a pool, so we delegate # the cache drop to the pool. @@ -1734,7 +2211,7 @@ def _drop_global_type_cache(self): else: self._drop_local_type_cache() - async def reload_schema_state(self): + async def reload_schema_state(self) -> None: """Indicate that the database schema information must be reloaded. For performance reasons, asyncpg caches certain aspects of the @@ -1779,17 +2256,101 @@ async def reload_schema_state(self): self._drop_global_type_cache() self._drop_global_statement_cache() + @typing.overload async def _execute( self, - query, - args, - limit, - timeout, + query: str, + args: compat.Sequence[object], + limit: int, + timeout: float | None, *, - return_status=False, - ignore_custom_codec=False, - record_class=None - ): + return_status: typing.Literal[False] = ..., + ignore_custom_codec: bool = ..., + record_class: None = ... + ) -> _RecordsType[_RecordT]: + ... + + @typing.overload + async def _execute( + self, + query: str, + args: compat.Sequence[object], + limit: int, + timeout: float | None, + *, + return_status: typing.Literal[False] = ..., + ignore_custom_codec: bool = ..., + record_class: type[_OtherRecordT] + ) -> _RecordsType[_OtherRecordT]: + ... + + @typing.overload + async def _execute( + self, + query: str, + args: compat.Sequence[object], + limit: int, + timeout: float | None, + *, + return_status: typing.Literal[False] = ..., + ignore_custom_codec: bool = ..., + record_class: type[_OtherRecordT] | None + ) -> _RecordsType[_RecordT] | _RecordsType[_OtherRecordT]: + ... + + @typing.overload + async def _execute( + self, + query: str, + args: compat.Sequence[object], + limit: int, + timeout: float | None, + *, + return_status: typing.Literal[True], + ignore_custom_codec: bool = ..., + record_class: None = ... + ) -> _RecordsTupleType[_RecordT]: + ... + + @typing.overload + async def _execute( + self, + query: str, + args: compat.Sequence[object], + limit: int, + timeout: float | None, + *, + return_status: typing.Literal[True], + ignore_custom_codec: bool = ..., + record_class: type[_OtherRecordT] + ) -> _RecordsTupleType[_OtherRecordT]: + ... + + @typing.overload + async def _execute( + self, + query: str, + args: compat.Sequence[object], + limit: int, + timeout: float | None, + *, + return_status: typing.Literal[True], + ignore_custom_codec: bool = ..., + record_class: type[_OtherRecordT] | None + ) -> _RecordsTupleType[_RecordT] | _RecordsTupleType[_OtherRecordT]: + ... + + async def _execute( + self, + query: str, + args: compat.Sequence[object], + limit: int, + timeout: float | None, + *, + return_status: bool = False, + ignore_custom_codec: bool = False, + record_class: type[_OtherRecordT] | None = None + ) -> _RecordsType[typing.Any] | _RecordsTupleType[typing.Any]: with self._stmt_exclusive_section: result, _ = await self.__execute( query, @@ -1803,7 +2364,7 @@ async def _execute( return result @contextlib.contextmanager - def query_logger(self, callback): + def query_logger(self, callback: QueryLogger) -> compat.Iterator[None]: """Context manager that adds `callback` to the list of query loggers, and removes it upon exit. @@ -1834,7 +2395,9 @@ def __call__(self, record): self.remove_query_logger(callback) @contextlib.contextmanager - def _time_and_log(self, query, args, timeout): + def _time_and_log( + self, query: str, args: typing.Any, timeout: float | None + ) -> compat.Iterator[None]: start = time.monotonic() exception = None try: @@ -1854,23 +2417,127 @@ def _time_and_log(self, query, args, timeout): conn_params=self._params, ) for cb in self._query_loggers: - if cb.is_async: - self._loop.create_task(cb.cb(record)) - else: - self._loop.call_soon(cb.cb, record) + cb.invoke(self._loop, record) + @typing.overload async def __execute( self, - query, - args, - limit, - timeout, + query: str, + args: compat.Sequence[object], + limit: int, + timeout: float | None, *, - return_status=False, - ignore_custom_codec=False, - record_class=None - ): - executor = lambda stmt, timeout: self._protocol.bind_execute( + return_status: typing.Literal[False] = ..., + ignore_custom_codec: bool = ..., + record_class: None = ... + ) -> tuple[ + _RecordsType[_RecordT], _cprotocol.PreparedStatementState[_RecordT] + ]: + ... + + @typing.overload + async def __execute( + self, + query: str, + args: compat.Sequence[object], + limit: int, + timeout: float | None, + *, + return_status: typing.Literal[False] = ..., + ignore_custom_codec: bool = ..., + record_class: type[_OtherRecordT] + ) -> tuple[ + _RecordsType[_OtherRecordT], + _cprotocol.PreparedStatementState[_OtherRecordT] + ]: + ... + + @typing.overload + async def __execute( + self, + query: str, + args: compat.Sequence[object], + limit: int, + timeout: float | None, + *, + return_status: typing.Literal[True], + ignore_custom_codec: bool = ..., + record_class: None = ... + ) -> tuple[ + _RecordsTupleType[_RecordT], + _cprotocol.PreparedStatementState[_RecordT] + ]: + ... + + @typing.overload + async def __execute( + self, + query: str, + args: compat.Sequence[object], + limit: int, + timeout: float | None, + *, + return_status: typing.Literal[True], + ignore_custom_codec: bool = ..., + record_class: type[_OtherRecordT] + ) -> tuple[ + _RecordsTupleType[_OtherRecordT], + _cprotocol.PreparedStatementState[_OtherRecordT] + ]: + ... + + @typing.overload + async def __execute( + self, + query: str, + args: compat.Sequence[object], + limit: int, + timeout: float | None, + *, + return_status: bool, + ignore_custom_codec: bool = ..., + record_class: type[_OtherRecordT] | None + ) -> tuple[ + _RecordsTupleType[_RecordT], + _cprotocol.PreparedStatementState[_RecordT] + ] | tuple[ + _RecordsType[_RecordT], + _cprotocol.PreparedStatementState[_RecordT] + ] | tuple[ + _RecordsTupleType[_OtherRecordT], + _cprotocol.PreparedStatementState[_OtherRecordT] + ] | tuple[ + _RecordsType[_OtherRecordT], + _cprotocol.PreparedStatementState[_OtherRecordT] + ]: + ... + + async def __execute( + self, + query: str, + args: compat.Sequence[object], + limit: int, + timeout: float | None, + *, + return_status: bool = False, + ignore_custom_codec: bool = False, + record_class: type[_OtherRecordT] | None = None + ) -> tuple[ + _RecordsTupleType[_RecordT], + _cprotocol.PreparedStatementState[_RecordT] + ] | tuple[ + _RecordsType[_RecordT], + _cprotocol.PreparedStatementState[_RecordT] + ] | tuple[ + _RecordsTupleType[_OtherRecordT], + _cprotocol.PreparedStatementState[_OtherRecordT] + ] | tuple[ + _RecordsType[_OtherRecordT], + _cprotocol.PreparedStatementState[_OtherRecordT] + ]: + executor: Executor[ + _OtherRecordT + ] = lambda stmt, timeout: self._protocol.bind_execute( state=stmt, args=args, portal_name='', @@ -1898,8 +2565,15 @@ async def __execute( ) return result, stmt - async def _executemany(self, query, args, timeout): - executor = lambda stmt, timeout: self._protocol.bind_execute_many( + async def _executemany( + self, + query: str, + args: compat.Iterable[compat.Sequence[object]], + timeout: float | None, + ) -> None: + executor: Executor[ + _RecordT + ] = lambda stmt, timeout: self._protocol.bind_execute_many( state=stmt, args=args, portal_name='', @@ -1908,19 +2582,20 @@ async def _executemany(self, query, args, timeout): timeout = self._protocol._get_timeout(timeout) with self._stmt_exclusive_section: with self._time_and_log(query, args, timeout): + result: None result, _ = await self._do_execute(query, executor, timeout) return result async def _do_execute( self, - query, - executor, - timeout, - retry=True, + query: str, + executor: Executor[typing.Any], + timeout: float | None, + retry: bool = True, *, - ignore_custom_codec=False, - record_class=None - ): + ignore_custom_codec: bool = False, + record_class: type[_OtherRecordT] | None = None, + ) -> tuple[typing.Any, _cprotocol.PreparedStatementState[typing.Any]]: if timeout is None: stmt = await self._get_statement( query, @@ -1948,7 +2623,7 @@ async def _do_execute( result = await executor(stmt, timeout) finally: after = time.monotonic() - timeout -= after - before + timeout -= after - before # pyright: ignore [reportPossiblyUnboundVariable] # noqa: E501 except exceptions.OutdatedSchemaCacheError: # This exception is raised when we detect a difference between @@ -1992,22 +2667,103 @@ async def _do_execute( return result, stmt -async def connect(dsn=None, *, - host=None, port=None, - user=None, password=None, passfile=None, - database=None, - loop=None, - timeout=60, - statement_cache_size=100, - max_cached_statement_lifetime=300, - max_cacheable_statement_size=1024 * 15, - command_timeout=None, - ssl=None, - direct_tls=False, - connection_class=Connection, - record_class=protocol.Record, - server_settings=None, - target_session_attrs=None): +@typing.overload +async def connect( + dsn: str | None = ..., + *, + host: connect_utils.HostType | None = ..., + port: connect_utils.PortType | None = ..., + user: str | None = ..., + password: connect_utils.PasswordType | None = ..., + passfile: str | None = ..., + database: str | None = ..., + loop: asyncio.AbstractEventLoop | None = ..., + timeout: float = ..., + statement_cache_size: int = ..., + max_cached_statement_lifetime: int = ..., + max_cacheable_statement_size: int = ..., + command_timeout: float | None = ..., + ssl: connect_utils.SSLType | None = ..., + direct_tls: bool = ..., + record_class: type[_RecordT], + server_settings: dict[str, str] | None = ..., + target_session_attrs: connect_utils.SessionAttribute | None = ..., +) -> Connection[_RecordT]: + ... + + +@typing.overload +async def connect( + dsn: str | None = ..., + *, + host: connect_utils.HostType | None = ..., + port: connect_utils.PortType | None = ..., + user: str | None = ..., + password: connect_utils.PasswordType | None = ..., + passfile: str | None = ..., + database: str | None = ..., + loop: asyncio.AbstractEventLoop | None = ..., + timeout: float = ..., + statement_cache_size: int = ..., + max_cached_statement_lifetime: int = ..., + max_cacheable_statement_size: int = ..., + command_timeout: float | None = ..., + ssl: connect_utils.SSLType | None = ..., + direct_tls: bool = ..., + connection_class: type[_ConnectionT], + record_class: type[_RecordT] = ..., + server_settings: dict[str, str] | None = ..., + target_session_attrs: connect_utils.SessionAttribute | None = ..., +) -> _ConnectionT: + ... + + +@typing.overload +async def connect( + dsn: str | None = ..., + *, + host: connect_utils.HostType | None = ..., + port: connect_utils.PortType | None = ..., + user: str | None = ..., + password: connect_utils.PasswordType | None = ..., + passfile: str | None = ..., + database: str | None = ..., + loop: asyncio.AbstractEventLoop | None = ..., + timeout: float = ..., + statement_cache_size: int = ..., + max_cached_statement_lifetime: int = ..., + max_cacheable_statement_size: int = ..., + command_timeout: float | None = ..., + ssl: connect_utils.SSLType | None = ..., + direct_tls: bool = ..., + server_settings: dict[str, str] | None = ..., + target_session_attrs: connect_utils.SessionAttribute | None = ..., +) -> Connection[protocol.Record]: + ... + + +async def connect( + dsn: str | None = None, + *, + host: connect_utils.HostType | None = None, + port: connect_utils.PortType | None = None, + user: str | None = None, + password: connect_utils.PasswordType | None = None, + passfile: str | None = None, + database: str | None = None, + loop: asyncio.AbstractEventLoop | None = None, + timeout: float = 60, + statement_cache_size: int = 100, + max_cached_statement_lifetime: int = 300, + max_cacheable_statement_size: int = 1024 * 15, + command_timeout: float | None = None, + ssl: connect_utils.SSLType | None = None, + direct_tls: bool = False, + connection_class: type[_ConnectionT] = typing.cast(typing.Any, Connection), + record_class: type[_RecordT] = typing.cast(typing.Any, protocol.Record), + server_settings: dict[str, str] | None = None, + target_session_attrs: connect_utils.SessionAttribute | None = None, +) -> Connection[typing.Any]: r"""A coroutine to establish a connection to a PostgreSQL server. The connection parameters may be specified either as a connection @@ -2348,11 +3104,24 @@ async def connect(dsn=None, *, ) -class _StatementCacheEntry: +_StatementCacheKey = compat.tuple[str, 'type[_RecordT]', bool] + + +class _StatementCacheEntry(typing.Generic[_RecordT]): __slots__ = ('_query', '_statement', '_cache', '_cleanup_cb') - def __init__(self, cache, query, statement): + _query: _StatementCacheKey[_RecordT] + _statement: _cprotocol.PreparedStatementState[_RecordT] + _cache: _StatementCache + _cleanup_cb: asyncio.TimerHandle | None + + def __init__( + self, + cache: _StatementCache, + query: _StatementCacheKey[_RecordT], + statement: _cprotocol.PreparedStatementState[_RecordT] + ) -> None: self._cache = cache self._query = query self._statement = statement @@ -2364,7 +3133,23 @@ class _StatementCache: __slots__ = ('_loop', '_entries', '_max_size', '_on_remove', '_max_lifetime') - def __init__(self, *, loop, max_size, on_remove, max_lifetime): + _loop: asyncio.AbstractEventLoop + _entries: compat.OrderedDict[ + _StatementCacheKey[typing.Any], + _StatementCacheEntry[typing.Any] + ] + _max_size: int + _on_remove: OnRemove[typing.Any] + _max_lifetime: float + + def __init__( + self, + *, + loop: asyncio.AbstractEventLoop, + max_size: int, + on_remove: OnRemove[typing.Any], + max_lifetime: float + ) -> None: self._loop = loop self._max_size = max_size self._on_remove = on_remove @@ -2389,21 +3174,21 @@ def __init__(self, *, loop, max_size, on_remove, max_lifetime): # beginning of it. self._entries = collections.OrderedDict() - def __len__(self): + def __len__(self) -> int: return len(self._entries) - def get_max_size(self): + def get_max_size(self) -> int: return self._max_size - def set_max_size(self, new_size): + def set_max_size(self, new_size: int) -> None: assert new_size >= 0 self._max_size = new_size self._maybe_cleanup() - def get_max_lifetime(self): + def get_max_lifetime(self) -> float: return self._max_lifetime - def set_max_lifetime(self, new_lifetime): + def set_max_lifetime(self, new_lifetime: float) -> None: assert new_lifetime >= 0 self._max_lifetime = new_lifetime for entry in self._entries.values(): @@ -2411,14 +3196,16 @@ def set_max_lifetime(self, new_lifetime): # and setup a new one if necessary. self._set_entry_timeout(entry) - def get(self, query, *, promote=True): + def get( + self, query: _StatementCacheKey[_RecordT], *, promote: bool = True + ) -> _cprotocol.PreparedStatementState[_RecordT] | None: if not self._max_size: # The cache is disabled. - return + return None - entry = self._entries.get(query) # type: _StatementCacheEntry + entry: _StatementCacheEntry[_RecordT] | None = self._entries.get(query) if entry is None: - return + return None if entry._statement.closed: # Happens in unittests when we call `stmt._state.mark_closed()` @@ -2426,7 +3213,7 @@ def get(self, query, *, promote=True): # cache error. self._entries.pop(query) self._clear_entry_callback(entry) - return + return None if promote: # `promote` is `False` when `get()` is called by `has()`. @@ -2434,10 +3221,14 @@ def get(self, query, *, promote=True): return entry._statement - def has(self, query): + def has(self, query: _StatementCacheKey[_RecordT]) -> bool: return self.get(query, promote=False) is not None - def put(self, query, statement): + def put( + self, + query: _StatementCacheKey[_RecordT], + statement: _cprotocol.PreparedStatementState[_RecordT], + ) -> None: if not self._max_size: # The cache is disabled. return @@ -2448,10 +3239,12 @@ def put(self, query, statement): # if necessary. self._maybe_cleanup() - def iter_statements(self): + def iter_statements( + self + ) -> compat.Iterator[_cprotocol.PreparedStatementState[typing.Any]]: return (e._statement for e in self._entries.values()) - def clear(self): + def clear(self) -> None: # Store entries for later. entries = tuple(self._entries.values()) @@ -2464,7 +3257,9 @@ def clear(self): self._clear_entry_callback(entry) self._on_remove(entry._statement) - def _set_entry_timeout(self, entry): + def _set_entry_timeout( + self, entry: _StatementCacheEntry[typing.Any] + ) -> None: # Clear the existing timeout. self._clear_entry_callback(entry) @@ -2473,23 +3268,31 @@ def _set_entry_timeout(self, entry): entry._cleanup_cb = self._loop.call_later( self._max_lifetime, self._on_entry_expired, entry) - def _new_entry(self, query, statement): + def _new_entry( + self, + query: _StatementCacheKey[_RecordT], + statement: _cprotocol.PreparedStatementState[_RecordT], + ) -> _StatementCacheEntry[_RecordT]: entry = _StatementCacheEntry(self, query, statement) self._set_entry_timeout(entry) return entry - def _on_entry_expired(self, entry): + def _on_entry_expired( + self, entry: _StatementCacheEntry[typing.Any] + ) -> None: # `call_later` callback, called when an entry stayed longer # than `self._max_lifetime`. if self._entries.get(entry._query) is entry: self._entries.pop(entry._query) self._on_remove(entry._statement) - def _clear_entry_callback(self, entry): + def _clear_entry_callback( + self, entry: _StatementCacheEntry[typing.Any] + ) -> None: if entry._cleanup_cb is not None: entry._cleanup_cb.cancel() - def _maybe_cleanup(self): + def _maybe_cleanup(self) -> None: # Delete cache entries until the size of the cache is `max_size`. while len(self._entries) > self._max_size: old_query, old_entry = self._entries.popitem(last=False) @@ -2500,13 +3303,35 @@ def _maybe_cleanup(self): self._on_remove(old_entry._statement) -class _Callback(typing.NamedTuple): +_CallbackType = compat.Callable[ + _P, + 'compat.Coroutine[typing.Any, typing.Any, None] | None' +] + + +@dataclasses.dataclass(frozen=True) +class _Callback(typing.Generic[_P]): + __slots__ = ('cb', 'is_async') - cb: typing.Callable[..., None] + cb: _CallbackType[_P] is_async: bool + def invoke( + self, + loop: asyncio.AbstractEventLoop, + /, + *args: _P.args, + **kwargs: _P.kwargs, + ) -> None: + if self.is_async: + loop.create_task( + typing.cast(typing.Any, self.cb(*args, **kwargs)) + ) + else: + loop.call_soon(lambda: self.cb(*args, **kwargs)) + @classmethod - def from_callable(cls, cb: typing.Callable[..., None]) -> '_Callback': + def from_callable(cls, cb: _CallbackType[_P]) -> Self: if inspect.iscoroutinefunction(cb): is_async = True elif callable(cb): @@ -2523,39 +3348,52 @@ def from_callable(cls, cb: typing.Callable[..., None]) -> '_Callback': class _Atomic: __slots__ = ('_acquired',) - def __init__(self): + _acquired: int + + def __init__(self) -> None: self._acquired = 0 - def __enter__(self): + def __enter__(self) -> None: if self._acquired: raise exceptions.InterfaceError( 'cannot perform operation: another operation is in progress') self._acquired = 1 - def __exit__(self, t, e, tb): + def __exit__(self, t: object, e: object, tb: object) -> None: self._acquired = 0 -class _ConnectionProxy: +class _ConnectionProxy(typing.Generic[_RecordT]): # Base class to enable `isinstance(Connection)` check. __slots__ = () -LoggedQuery = collections.namedtuple( - 'LoggedQuery', - ['query', 'args', 'timeout', 'elapsed', 'exception', 'conn_addr', - 'conn_params']) -LoggedQuery.__doc__ = 'Log record of an executed query.' - - -ServerCapabilities = collections.namedtuple( - 'ServerCapabilities', - ['advisory_locks', 'notifications', 'plpgsql', 'sql_reset', - 'sql_close_all', 'sql_copy_from_where', 'jit']) -ServerCapabilities.__doc__ = 'PostgreSQL server capabilities.' - - -def _detect_server_capabilities(server_version, connection_settings): +class LoggedQuery(typing.NamedTuple): + '''Log record of an executed query.''' + query: str + args: typing.Any + timeout: float | None + elapsed: float + exception: BaseException | None + conn_addr: tuple[str, int] | str + conn_params: connect_utils._ConnectionParameters + + +class ServerCapabilities(typing.NamedTuple): + '''PostgreSQL server capabilities.''' + advisory_locks: bool + notifications: bool + plpgsql: bool + sql_reset: bool + sql_close_all: bool + sql_copy_from_where: bool + jit: bool + + +def _detect_server_capabilities( + server_version: types.ServerVersion, + connection_settings: _cprotocol.ConnectionSettings, +) -> ServerCapabilities: if hasattr(connection_settings, 'padb_revision'): # Amazon Redshift detected. advisory_locks = False @@ -2604,18 +3442,18 @@ def _detect_server_capabilities(server_version, connection_settings): ) -def _extract_stack(limit=10): +def _extract_stack(limit: int = 10) -> str: """Replacement for traceback.extract_stack() that only does the necessary work for asyncio debug mode. """ frame = sys._getframe().f_back try: - stack = traceback.StackSummary.extract( + stack: list[traceback.FrameSummary] = traceback.StackSummary.extract( traceback.walk_stack(frame), lookup_lines=False) finally: del frame - apg_path = asyncpg.__path__[0] + apg_path = list(asyncpg.__path__)[0] i = 0 while i < len(stack) and stack[i][0].startswith(apg_path): i += 1 @@ -2625,7 +3463,7 @@ def _extract_stack(limit=10): return ''.join(traceback.format_list(stack)) -def _check_record_class(record_class): +def _check_record_class(record_class: type[typing.Any]) -> None: if record_class is protocol.Record: pass elif ( @@ -2646,7 +3484,10 @@ def _check_record_class(record_class): ) -def _weak_maybe_gc_stmt(weak_ref, stmt): +def _weak_maybe_gc_stmt( + weak_ref: weakref.ref[Connection[typing.Any]], + stmt: _cprotocol.PreparedStatementState[typing.Any], +) -> None: self = weak_ref() if self is not None: self._maybe_gc_stmt(stmt) diff --git a/asyncpg/connresource.py b/asyncpg/connresource.py index 3b0c1d3c..60aa97a6 100644 --- a/asyncpg/connresource.py +++ b/asyncpg/connresource.py @@ -5,31 +5,46 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +from __future__ import annotations import functools +import typing from . import exceptions +if typing.TYPE_CHECKING: + from . import compat + from . import connection as _conn -def guarded(meth): +_F = typing.TypeVar('_F', bound='compat.Callable[..., typing.Any]') + + +def guarded(meth: _F) -> _F: """A decorator to add a sanity check to ConnectionResource methods.""" @functools.wraps(meth) - def _check(self, *args, **kwargs): + def _check( + self: ConnectionResource, + *args: typing.Any, + **kwargs: typing.Any + ) -> typing.Any: self._check_conn_validity(meth.__name__) return meth(self, *args, **kwargs) - return _check + return typing.cast(_F, _check) class ConnectionResource: __slots__ = ('_connection', '_con_release_ctr') - def __init__(self, connection): + _connection: _conn.Connection[typing.Any] + _con_release_ctr: int + + def __init__(self, connection: _conn.Connection[typing.Any]) -> None: self._connection = connection self._con_release_ctr = connection._pool_release_ctr - def _check_conn_validity(self, meth_name): + def _check_conn_validity(self, meth_name: str) -> None: con_release_ctr = self._connection._pool_release_ctr if con_release_ctr != self._con_release_ctr: raise exceptions.InterfaceError( diff --git a/asyncpg/cursor.py b/asyncpg/cursor.py index b4abeed1..0b3980ba 100644 --- a/asyncpg/cursor.py +++ b/asyncpg/cursor.py @@ -4,14 +4,30 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +from __future__ import annotations import collections +import typing from . import connresource from . import exceptions +if typing.TYPE_CHECKING: + import sys -class CursorFactory(connresource.ConnectionResource): + if sys.version_info < (3, 11): + from typing_extensions import Self + else: + from typing import Self + + from .protocol import protocol as _cprotocol + from . import connection as _connection + from . import compat + +_RecordT = typing.TypeVar('_RecordT', bound='_cprotocol.Record') + + +class CursorFactory(connresource.ConnectionResource, typing.Generic[_RecordT]): """A cursor interface for the results of a query. A cursor interface can be used to initiate efficient traversal of the @@ -27,16 +43,49 @@ class CursorFactory(connresource.ConnectionResource): '_record_class', ) + _state: _cprotocol.PreparedStatementState[_RecordT] | None + _args: compat.Sequence[object] + _prefetch: int | None + _query: str + _timeout: float | None + _record_class: type[_RecordT] | None + + @typing.overload + def __init__( + self: CursorFactory[_RecordT], + connection: _connection.Connection[_RecordT], + query: str, + state: _cprotocol.PreparedStatementState[_RecordT] | None, + args: compat.Sequence[object], + prefetch: int | None, + timeout: float | None, + record_class: None + ) -> None: + ... + + @typing.overload + def __init__( + self: CursorFactory[_RecordT], + connection: _connection.Connection[typing.Any], + query: str, + state: _cprotocol.PreparedStatementState[_RecordT] | None, + args: compat.Sequence[object], + prefetch: int | None, + timeout: float | None, + record_class: type[_RecordT] + ) -> None: + ... + def __init__( self, - connection, - query, - state, - args, - prefetch, - timeout, - record_class - ): + connection: _connection.Connection[typing.Any], + query: str, + state: _cprotocol.PreparedStatementState[_RecordT] | None, + args: compat.Sequence[object], + prefetch: int | None, + timeout: float | None, + record_class: type[_RecordT] | None + ) -> None: super().__init__(connection) self._args = args self._prefetch = prefetch @@ -48,7 +97,7 @@ def __init__( state.attach() @connresource.guarded - def __aiter__(self): + def __aiter__(self) -> CursorIterator[_RecordT]: prefetch = 50 if self._prefetch is None else self._prefetch return CursorIterator( self._connection, @@ -61,11 +110,13 @@ def __aiter__(self): ) @connresource.guarded - def __await__(self): + def __await__( + self + ) -> compat.Generator[typing.Any, None, Cursor[_RecordT]]: if self._prefetch is not None: raise exceptions.InterfaceError( 'prefetch argument can only be specified for iterable cursor') - cursor = Cursor( + cursor: Cursor[_RecordT] = Cursor( self._connection, self._query, self._state, @@ -74,13 +125,13 @@ def __await__(self): ) return cursor._init(self._timeout).__await__() - def __del__(self): + def __del__(self) -> None: if self._state is not None: self._state.detach() self._connection._maybe_gc_stmt(self._state) -class BaseCursor(connresource.ConnectionResource): +class BaseCursor(connresource.ConnectionResource, typing.Generic[_RecordT]): __slots__ = ( '_state', @@ -91,7 +142,43 @@ class BaseCursor(connresource.ConnectionResource): '_record_class', ) - def __init__(self, connection, query, state, args, record_class): + _state: _cprotocol.PreparedStatementState[_RecordT] | None + _args: compat.Sequence[object] + _portal_name: str | None + _exhausted: bool + _query: str + _record_class: type[_RecordT] | None + + @typing.overload + def __init__( + self: BaseCursor[_RecordT], + connection: _connection.Connection[_RecordT], + query: str, + state: _cprotocol.PreparedStatementState[_RecordT] | None, + args: compat.Sequence[object], + record_class: None, + ) -> None: + ... + + @typing.overload + def __init__( + self: BaseCursor[_RecordT], + connection: _connection.Connection[typing.Any], + query: str, + state: _cprotocol.PreparedStatementState[_RecordT] | None, + args: compat.Sequence[object], + record_class: type[_RecordT], + ) -> None: + ... + + def __init__( + self, + connection: _connection.Connection[typing.Any], + query: str, + state: _cprotocol.PreparedStatementState[_RecordT] | None, + args: compat.Sequence[object], + record_class: type[_RecordT] | None, + ) -> None: super().__init__(connection) self._args = args self._state = state @@ -102,7 +189,7 @@ def __init__(self, connection, query, state, args, record_class): self._query = query self._record_class = record_class - def _check_ready(self): + def _check_ready(self) -> None: if self._state is None: raise exceptions.InterfaceError( 'cursor: no associated prepared statement') @@ -115,7 +202,7 @@ def _check_ready(self): raise exceptions.NoActiveSQLTransactionError( 'cursor cannot be created outside of a transaction') - async def _bind_exec(self, n, timeout): + async def _bind_exec(self, n: int, timeout: float | None) -> typing.Any: self._check_ready() if self._portal_name: @@ -126,11 +213,15 @@ async def _bind_exec(self, n, timeout): protocol = con._protocol self._portal_name = con._get_unique_id('portal') + + if typing.TYPE_CHECKING: + assert self._state is not None + buffer, _, self._exhausted = await protocol.bind_execute( self._state, self._args, self._portal_name, n, True, timeout) return buffer - async def _bind(self, timeout): + async def _bind(self, timeout: float | None) -> typing.Any: self._check_ready() if self._portal_name: @@ -141,12 +232,16 @@ async def _bind(self, timeout): protocol = con._protocol self._portal_name = con._get_unique_id('portal') + + if typing.TYPE_CHECKING: + assert self._state is not None + buffer = await protocol.bind(self._state, self._args, self._portal_name, timeout) return buffer - async def _exec(self, n, timeout): + async def _exec(self, n: int, timeout: float | None) -> typing.Any: self._check_ready() if not self._portal_name: @@ -158,7 +253,7 @@ async def _exec(self, n, timeout): self._state, self._portal_name, n, True, timeout) return buffer - async def _close_portal(self, timeout): + async def _close_portal(self, timeout: float | None) -> None: self._check_ready() if not self._portal_name: @@ -169,8 +264,8 @@ async def _close_portal(self, timeout): await protocol.close_portal(self._portal_name, timeout) self._portal_name = None - def __repr__(self): - attrs = [] + def __repr__(self) -> str: + attrs: list[str] = [] if self._exhausted: attrs.append('exhausted') attrs.append('') # to separate from id @@ -182,29 +277,59 @@ def __repr__(self): return '<{}.{} "{!s:.30}" {}{:#x}>'.format( mod, self.__class__.__name__, - self._state.query, + self._state.query if self._state is not None else '', ' '.join(attrs), id(self)) - def __del__(self): + def __del__(self) -> None: if self._state is not None: self._state.detach() self._connection._maybe_gc_stmt(self._state) -class CursorIterator(BaseCursor): +class CursorIterator(BaseCursor[_RecordT]): __slots__ = ('_buffer', '_prefetch', '_timeout') + _buffer: compat.deque[_RecordT] + _prefetch: int + _timeout: float | None + + @typing.overload + def __init__( + self: CursorIterator[_RecordT], + connection: _connection.Connection[_RecordT], + query: str, + state: _cprotocol.PreparedStatementState[_RecordT] | None, + args: compat.Sequence[object], + record_class: None, + prefetch: int, + timeout: float | None, + ) -> None: + ... + + @typing.overload + def __init__( + self: CursorIterator[_RecordT], + connection: _connection.Connection[typing.Any], + query: str, + state: _cprotocol.PreparedStatementState[_RecordT] | None, + args: compat.Sequence[object], + record_class: type[_RecordT], + prefetch: int, + timeout: float | None, + ) -> None: + ... + def __init__( self, - connection, - query, - state, - args, - record_class, - prefetch, - timeout - ): + connection: _connection.Connection[typing.Any], + query: str, + state: _cprotocol.PreparedStatementState[_RecordT] | None, + args: compat.Sequence[object], + record_class: type[_RecordT] | None, + prefetch: int, + timeout: float | None, + ) -> None: super().__init__(connection, query, state, args, record_class) if prefetch <= 0: @@ -216,11 +341,11 @@ def __init__( self._timeout = timeout @connresource.guarded - def __aiter__(self): + def __aiter__(self) -> Self: return self @connresource.guarded - async def __anext__(self): + async def __anext__(self) -> _RecordT: if self._state is None: self._state = await self._connection._get_statement( self._query, @@ -247,12 +372,12 @@ async def __anext__(self): raise StopAsyncIteration -class Cursor(BaseCursor): +class Cursor(BaseCursor[_RecordT]): """An open *portal* into the results of a query.""" __slots__ = () - async def _init(self, timeout): + async def _init(self, timeout: float | None) -> Self: if self._state is None: self._state = await self._connection._get_statement( self._query, @@ -266,7 +391,9 @@ async def _init(self, timeout): return self @connresource.guarded - async def fetch(self, n, *, timeout=None): + async def fetch( + self, n: int, *, timeout: float | None = None + ) -> list[_RecordT]: r"""Return the next *n* rows as a list of :class:`Record` objects. :param float timeout: Optional timeout value in seconds. @@ -278,13 +405,15 @@ async def fetch(self, n, *, timeout=None): raise exceptions.InterfaceError('n must be greater than zero') if self._exhausted: return [] - recs = await self._exec(n, timeout) + recs: list[_RecordT] = await self._exec(n, timeout) if len(recs) < n: self._exhausted = True return recs @connresource.guarded - async def fetchrow(self, *, timeout=None): + async def fetchrow( + self, *, timeout: float | None = None + ) -> _RecordT | None: r"""Return the next row. :param float timeout: Optional timeout value in seconds. @@ -294,14 +423,14 @@ async def fetchrow(self, *, timeout=None): self._check_ready() if self._exhausted: return None - recs = await self._exec(1, timeout) + recs: list[_RecordT] = await self._exec(1, timeout) if len(recs) < 1: self._exhausted = True return None return recs[0] @connresource.guarded - async def forward(self, n, *, timeout=None) -> int: + async def forward(self, n: int, *, timeout: float | None = None) -> int: r"""Skip over the next *n* rows. :param float timeout: Optional timeout value in seconds. diff --git a/asyncpg/exceptions/__init__.py b/asyncpg/exceptions/__init__.py index 8c97d5a0..4769f766 100644 --- a/asyncpg/exceptions/__init__.py +++ b/asyncpg/exceptions/__init__.py @@ -1,88 +1,91 @@ # GENERATED FROM postgresql/src/backend/utils/errcodes.txt # DO NOT MODIFY, use tools/generate_exceptions.py to update +from __future__ import annotations + +import typing from ._base import * # NOQA from . import _base class PostgresWarning(_base.PostgresLogMessage, Warning): - sqlstate = '01000' + sqlstate: typing.ClassVar[str] = '01000' class DynamicResultSetsReturned(PostgresWarning): - sqlstate = '0100C' + sqlstate: typing.ClassVar[str] = '0100C' class ImplicitZeroBitPadding(PostgresWarning): - sqlstate = '01008' + sqlstate: typing.ClassVar[str] = '01008' class NullValueEliminatedInSetFunction(PostgresWarning): - sqlstate = '01003' + sqlstate: typing.ClassVar[str] = '01003' class PrivilegeNotGranted(PostgresWarning): - sqlstate = '01007' + sqlstate: typing.ClassVar[str] = '01007' class PrivilegeNotRevoked(PostgresWarning): - sqlstate = '01006' + sqlstate: typing.ClassVar[str] = '01006' class StringDataRightTruncation(PostgresWarning): - sqlstate = '01004' + sqlstate: typing.ClassVar[str] = '01004' class DeprecatedFeature(PostgresWarning): - sqlstate = '01P01' + sqlstate: typing.ClassVar[str] = '01P01' class NoData(PostgresWarning): - sqlstate = '02000' + sqlstate: typing.ClassVar[str] = '02000' class NoAdditionalDynamicResultSetsReturned(NoData): - sqlstate = '02001' + sqlstate: typing.ClassVar[str] = '02001' class SQLStatementNotYetCompleteError(_base.PostgresError): - sqlstate = '03000' + sqlstate: typing.ClassVar[str] = '03000' class PostgresConnectionError(_base.PostgresError): - sqlstate = '08000' + sqlstate: typing.ClassVar[str] = '08000' class ConnectionDoesNotExistError(PostgresConnectionError): - sqlstate = '08003' + sqlstate: typing.ClassVar[str] = '08003' class ConnectionFailureError(PostgresConnectionError): - sqlstate = '08006' + sqlstate: typing.ClassVar[str] = '08006' class ClientCannotConnectError(PostgresConnectionError): - sqlstate = '08001' + sqlstate: typing.ClassVar[str] = '08001' class ConnectionRejectionError(PostgresConnectionError): - sqlstate = '08004' + sqlstate: typing.ClassVar[str] = '08004' class TransactionResolutionUnknownError(PostgresConnectionError): - sqlstate = '08007' + sqlstate: typing.ClassVar[str] = '08007' class ProtocolViolationError(PostgresConnectionError): - sqlstate = '08P01' + sqlstate: typing.ClassVar[str] = '08P01' class TriggeredActionError(_base.PostgresError): - sqlstate = '09000' + sqlstate: typing.ClassVar[str] = '09000' class FeatureNotSupportedError(_base.PostgresError): - sqlstate = '0A000' + sqlstate: typing.ClassVar[str] = '0A000' class InvalidCachedStatementError(FeatureNotSupportedError): @@ -90,969 +93,969 @@ class InvalidCachedStatementError(FeatureNotSupportedError): class InvalidTransactionInitiationError(_base.PostgresError): - sqlstate = '0B000' + sqlstate: typing.ClassVar[str] = '0B000' class LocatorError(_base.PostgresError): - sqlstate = '0F000' + sqlstate: typing.ClassVar[str] = '0F000' class InvalidLocatorSpecificationError(LocatorError): - sqlstate = '0F001' + sqlstate: typing.ClassVar[str] = '0F001' class InvalidGrantorError(_base.PostgresError): - sqlstate = '0L000' + sqlstate: typing.ClassVar[str] = '0L000' class InvalidGrantOperationError(InvalidGrantorError): - sqlstate = '0LP01' + sqlstate: typing.ClassVar[str] = '0LP01' class InvalidRoleSpecificationError(_base.PostgresError): - sqlstate = '0P000' + sqlstate: typing.ClassVar[str] = '0P000' class DiagnosticsError(_base.PostgresError): - sqlstate = '0Z000' + sqlstate: typing.ClassVar[str] = '0Z000' class StackedDiagnosticsAccessedWithoutActiveHandlerError(DiagnosticsError): - sqlstate = '0Z002' + sqlstate: typing.ClassVar[str] = '0Z002' class CaseNotFoundError(_base.PostgresError): - sqlstate = '20000' + sqlstate: typing.ClassVar[str] = '20000' class CardinalityViolationError(_base.PostgresError): - sqlstate = '21000' + sqlstate: typing.ClassVar[str] = '21000' class DataError(_base.PostgresError): - sqlstate = '22000' + sqlstate: typing.ClassVar[str] = '22000' class ArraySubscriptError(DataError): - sqlstate = '2202E' + sqlstate: typing.ClassVar[str] = '2202E' class CharacterNotInRepertoireError(DataError): - sqlstate = '22021' + sqlstate: typing.ClassVar[str] = '22021' class DatetimeFieldOverflowError(DataError): - sqlstate = '22008' + sqlstate: typing.ClassVar[str] = '22008' class DivisionByZeroError(DataError): - sqlstate = '22012' + sqlstate: typing.ClassVar[str] = '22012' class ErrorInAssignmentError(DataError): - sqlstate = '22005' + sqlstate: typing.ClassVar[str] = '22005' class EscapeCharacterConflictError(DataError): - sqlstate = '2200B' + sqlstate: typing.ClassVar[str] = '2200B' class IndicatorOverflowError(DataError): - sqlstate = '22022' + sqlstate: typing.ClassVar[str] = '22022' class IntervalFieldOverflowError(DataError): - sqlstate = '22015' + sqlstate: typing.ClassVar[str] = '22015' class InvalidArgumentForLogarithmError(DataError): - sqlstate = '2201E' + sqlstate: typing.ClassVar[str] = '2201E' class InvalidArgumentForNtileFunctionError(DataError): - sqlstate = '22014' + sqlstate: typing.ClassVar[str] = '22014' class InvalidArgumentForNthValueFunctionError(DataError): - sqlstate = '22016' + sqlstate: typing.ClassVar[str] = '22016' class InvalidArgumentForPowerFunctionError(DataError): - sqlstate = '2201F' + sqlstate: typing.ClassVar[str] = '2201F' class InvalidArgumentForWidthBucketFunctionError(DataError): - sqlstate = '2201G' + sqlstate: typing.ClassVar[str] = '2201G' class InvalidCharacterValueForCastError(DataError): - sqlstate = '22018' + sqlstate: typing.ClassVar[str] = '22018' class InvalidDatetimeFormatError(DataError): - sqlstate = '22007' + sqlstate: typing.ClassVar[str] = '22007' class InvalidEscapeCharacterError(DataError): - sqlstate = '22019' + sqlstate: typing.ClassVar[str] = '22019' class InvalidEscapeOctetError(DataError): - sqlstate = '2200D' + sqlstate: typing.ClassVar[str] = '2200D' class InvalidEscapeSequenceError(DataError): - sqlstate = '22025' + sqlstate: typing.ClassVar[str] = '22025' class NonstandardUseOfEscapeCharacterError(DataError): - sqlstate = '22P06' + sqlstate: typing.ClassVar[str] = '22P06' class InvalidIndicatorParameterValueError(DataError): - sqlstate = '22010' + sqlstate: typing.ClassVar[str] = '22010' class InvalidParameterValueError(DataError): - sqlstate = '22023' + sqlstate: typing.ClassVar[str] = '22023' class InvalidPrecedingOrFollowingSizeError(DataError): - sqlstate = '22013' + sqlstate: typing.ClassVar[str] = '22013' class InvalidRegularExpressionError(DataError): - sqlstate = '2201B' + sqlstate: typing.ClassVar[str] = '2201B' class InvalidRowCountInLimitClauseError(DataError): - sqlstate = '2201W' + sqlstate: typing.ClassVar[str] = '2201W' class InvalidRowCountInResultOffsetClauseError(DataError): - sqlstate = '2201X' + sqlstate: typing.ClassVar[str] = '2201X' class InvalidTablesampleArgumentError(DataError): - sqlstate = '2202H' + sqlstate: typing.ClassVar[str] = '2202H' class InvalidTablesampleRepeatError(DataError): - sqlstate = '2202G' + sqlstate: typing.ClassVar[str] = '2202G' class InvalidTimeZoneDisplacementValueError(DataError): - sqlstate = '22009' + sqlstate: typing.ClassVar[str] = '22009' class InvalidUseOfEscapeCharacterError(DataError): - sqlstate = '2200C' + sqlstate: typing.ClassVar[str] = '2200C' class MostSpecificTypeMismatchError(DataError): - sqlstate = '2200G' + sqlstate: typing.ClassVar[str] = '2200G' class NullValueNotAllowedError(DataError): - sqlstate = '22004' + sqlstate: typing.ClassVar[str] = '22004' class NullValueNoIndicatorParameterError(DataError): - sqlstate = '22002' + sqlstate: typing.ClassVar[str] = '22002' class NumericValueOutOfRangeError(DataError): - sqlstate = '22003' + sqlstate: typing.ClassVar[str] = '22003' class SequenceGeneratorLimitExceededError(DataError): - sqlstate = '2200H' + sqlstate: typing.ClassVar[str] = '2200H' class StringDataLengthMismatchError(DataError): - sqlstate = '22026' + sqlstate: typing.ClassVar[str] = '22026' class StringDataRightTruncationError(DataError): - sqlstate = '22001' + sqlstate: typing.ClassVar[str] = '22001' class SubstringError(DataError): - sqlstate = '22011' + sqlstate: typing.ClassVar[str] = '22011' class TrimError(DataError): - sqlstate = '22027' + sqlstate: typing.ClassVar[str] = '22027' class UnterminatedCStringError(DataError): - sqlstate = '22024' + sqlstate: typing.ClassVar[str] = '22024' class ZeroLengthCharacterStringError(DataError): - sqlstate = '2200F' + sqlstate: typing.ClassVar[str] = '2200F' class PostgresFloatingPointError(DataError): - sqlstate = '22P01' + sqlstate: typing.ClassVar[str] = '22P01' class InvalidTextRepresentationError(DataError): - sqlstate = '22P02' + sqlstate: typing.ClassVar[str] = '22P02' class InvalidBinaryRepresentationError(DataError): - sqlstate = '22P03' + sqlstate: typing.ClassVar[str] = '22P03' class BadCopyFileFormatError(DataError): - sqlstate = '22P04' + sqlstate: typing.ClassVar[str] = '22P04' class UntranslatableCharacterError(DataError): - sqlstate = '22P05' + sqlstate: typing.ClassVar[str] = '22P05' class NotAnXmlDocumentError(DataError): - sqlstate = '2200L' + sqlstate: typing.ClassVar[str] = '2200L' class InvalidXmlDocumentError(DataError): - sqlstate = '2200M' + sqlstate: typing.ClassVar[str] = '2200M' class InvalidXmlContentError(DataError): - sqlstate = '2200N' + sqlstate: typing.ClassVar[str] = '2200N' class InvalidXmlCommentError(DataError): - sqlstate = '2200S' + sqlstate: typing.ClassVar[str] = '2200S' class InvalidXmlProcessingInstructionError(DataError): - sqlstate = '2200T' + sqlstate: typing.ClassVar[str] = '2200T' class DuplicateJsonObjectKeyValueError(DataError): - sqlstate = '22030' + sqlstate: typing.ClassVar[str] = '22030' class InvalidArgumentForSQLJsonDatetimeFunctionError(DataError): - sqlstate = '22031' + sqlstate: typing.ClassVar[str] = '22031' class InvalidJsonTextError(DataError): - sqlstate = '22032' + sqlstate: typing.ClassVar[str] = '22032' class InvalidSQLJsonSubscriptError(DataError): - sqlstate = '22033' + sqlstate: typing.ClassVar[str] = '22033' class MoreThanOneSQLJsonItemError(DataError): - sqlstate = '22034' + sqlstate: typing.ClassVar[str] = '22034' class NoSQLJsonItemError(DataError): - sqlstate = '22035' + sqlstate: typing.ClassVar[str] = '22035' class NonNumericSQLJsonItemError(DataError): - sqlstate = '22036' + sqlstate: typing.ClassVar[str] = '22036' class NonUniqueKeysInAJsonObjectError(DataError): - sqlstate = '22037' + sqlstate: typing.ClassVar[str] = '22037' class SingletonSQLJsonItemRequiredError(DataError): - sqlstate = '22038' + sqlstate: typing.ClassVar[str] = '22038' class SQLJsonArrayNotFoundError(DataError): - sqlstate = '22039' + sqlstate: typing.ClassVar[str] = '22039' class SQLJsonMemberNotFoundError(DataError): - sqlstate = '2203A' + sqlstate: typing.ClassVar[str] = '2203A' class SQLJsonNumberNotFoundError(DataError): - sqlstate = '2203B' + sqlstate: typing.ClassVar[str] = '2203B' class SQLJsonObjectNotFoundError(DataError): - sqlstate = '2203C' + sqlstate: typing.ClassVar[str] = '2203C' class TooManyJsonArrayElementsError(DataError): - sqlstate = '2203D' + sqlstate: typing.ClassVar[str] = '2203D' class TooManyJsonObjectMembersError(DataError): - sqlstate = '2203E' + sqlstate: typing.ClassVar[str] = '2203E' class SQLJsonScalarRequiredError(DataError): - sqlstate = '2203F' + sqlstate: typing.ClassVar[str] = '2203F' class SQLJsonItemCannotBeCastToTargetTypeError(DataError): - sqlstate = '2203G' + sqlstate: typing.ClassVar[str] = '2203G' class IntegrityConstraintViolationError(_base.PostgresError): - sqlstate = '23000' + sqlstate: typing.ClassVar[str] = '23000' class RestrictViolationError(IntegrityConstraintViolationError): - sqlstate = '23001' + sqlstate: typing.ClassVar[str] = '23001' class NotNullViolationError(IntegrityConstraintViolationError): - sqlstate = '23502' + sqlstate: typing.ClassVar[str] = '23502' class ForeignKeyViolationError(IntegrityConstraintViolationError): - sqlstate = '23503' + sqlstate: typing.ClassVar[str] = '23503' class UniqueViolationError(IntegrityConstraintViolationError): - sqlstate = '23505' + sqlstate: typing.ClassVar[str] = '23505' class CheckViolationError(IntegrityConstraintViolationError): - sqlstate = '23514' + sqlstate: typing.ClassVar[str] = '23514' class ExclusionViolationError(IntegrityConstraintViolationError): - sqlstate = '23P01' + sqlstate: typing.ClassVar[str] = '23P01' class InvalidCursorStateError(_base.PostgresError): - sqlstate = '24000' + sqlstate: typing.ClassVar[str] = '24000' class InvalidTransactionStateError(_base.PostgresError): - sqlstate = '25000' + sqlstate: typing.ClassVar[str] = '25000' class ActiveSQLTransactionError(InvalidTransactionStateError): - sqlstate = '25001' + sqlstate: typing.ClassVar[str] = '25001' class BranchTransactionAlreadyActiveError(InvalidTransactionStateError): - sqlstate = '25002' + sqlstate: typing.ClassVar[str] = '25002' class HeldCursorRequiresSameIsolationLevelError(InvalidTransactionStateError): - sqlstate = '25008' + sqlstate: typing.ClassVar[str] = '25008' class InappropriateAccessModeForBranchTransactionError( InvalidTransactionStateError): - sqlstate = '25003' + sqlstate: typing.ClassVar[str] = '25003' class InappropriateIsolationLevelForBranchTransactionError( InvalidTransactionStateError): - sqlstate = '25004' + sqlstate: typing.ClassVar[str] = '25004' class NoActiveSQLTransactionForBranchTransactionError( InvalidTransactionStateError): - sqlstate = '25005' + sqlstate: typing.ClassVar[str] = '25005' class ReadOnlySQLTransactionError(InvalidTransactionStateError): - sqlstate = '25006' + sqlstate: typing.ClassVar[str] = '25006' class SchemaAndDataStatementMixingNotSupportedError( InvalidTransactionStateError): - sqlstate = '25007' + sqlstate: typing.ClassVar[str] = '25007' class NoActiveSQLTransactionError(InvalidTransactionStateError): - sqlstate = '25P01' + sqlstate: typing.ClassVar[str] = '25P01' class InFailedSQLTransactionError(InvalidTransactionStateError): - sqlstate = '25P02' + sqlstate: typing.ClassVar[str] = '25P02' class IdleInTransactionSessionTimeoutError(InvalidTransactionStateError): - sqlstate = '25P03' + sqlstate: typing.ClassVar[str] = '25P03' class InvalidSQLStatementNameError(_base.PostgresError): - sqlstate = '26000' + sqlstate: typing.ClassVar[str] = '26000' class TriggeredDataChangeViolationError(_base.PostgresError): - sqlstate = '27000' + sqlstate: typing.ClassVar[str] = '27000' class InvalidAuthorizationSpecificationError(_base.PostgresError): - sqlstate = '28000' + sqlstate: typing.ClassVar[str] = '28000' class InvalidPasswordError(InvalidAuthorizationSpecificationError): - sqlstate = '28P01' + sqlstate: typing.ClassVar[str] = '28P01' class DependentPrivilegeDescriptorsStillExistError(_base.PostgresError): - sqlstate = '2B000' + sqlstate: typing.ClassVar[str] = '2B000' class DependentObjectsStillExistError( DependentPrivilegeDescriptorsStillExistError): - sqlstate = '2BP01' + sqlstate: typing.ClassVar[str] = '2BP01' class InvalidTransactionTerminationError(_base.PostgresError): - sqlstate = '2D000' + sqlstate: typing.ClassVar[str] = '2D000' class SQLRoutineError(_base.PostgresError): - sqlstate = '2F000' + sqlstate: typing.ClassVar[str] = '2F000' class FunctionExecutedNoReturnStatementError(SQLRoutineError): - sqlstate = '2F005' + sqlstate: typing.ClassVar[str] = '2F005' class ModifyingSQLDataNotPermittedError(SQLRoutineError): - sqlstate = '2F002' + sqlstate: typing.ClassVar[str] = '2F002' class ProhibitedSQLStatementAttemptedError(SQLRoutineError): - sqlstate = '2F003' + sqlstate: typing.ClassVar[str] = '2F003' class ReadingSQLDataNotPermittedError(SQLRoutineError): - sqlstate = '2F004' + sqlstate: typing.ClassVar[str] = '2F004' class InvalidCursorNameError(_base.PostgresError): - sqlstate = '34000' + sqlstate: typing.ClassVar[str] = '34000' class ExternalRoutineError(_base.PostgresError): - sqlstate = '38000' + sqlstate: typing.ClassVar[str] = '38000' class ContainingSQLNotPermittedError(ExternalRoutineError): - sqlstate = '38001' + sqlstate: typing.ClassVar[str] = '38001' class ModifyingExternalRoutineSQLDataNotPermittedError(ExternalRoutineError): - sqlstate = '38002' + sqlstate: typing.ClassVar[str] = '38002' class ProhibitedExternalRoutineSQLStatementAttemptedError( ExternalRoutineError): - sqlstate = '38003' + sqlstate: typing.ClassVar[str] = '38003' class ReadingExternalRoutineSQLDataNotPermittedError(ExternalRoutineError): - sqlstate = '38004' + sqlstate: typing.ClassVar[str] = '38004' class ExternalRoutineInvocationError(_base.PostgresError): - sqlstate = '39000' + sqlstate: typing.ClassVar[str] = '39000' class InvalidSqlstateReturnedError(ExternalRoutineInvocationError): - sqlstate = '39001' + sqlstate: typing.ClassVar[str] = '39001' class NullValueInExternalRoutineNotAllowedError( ExternalRoutineInvocationError): - sqlstate = '39004' + sqlstate: typing.ClassVar[str] = '39004' class TriggerProtocolViolatedError(ExternalRoutineInvocationError): - sqlstate = '39P01' + sqlstate: typing.ClassVar[str] = '39P01' class SrfProtocolViolatedError(ExternalRoutineInvocationError): - sqlstate = '39P02' + sqlstate: typing.ClassVar[str] = '39P02' class EventTriggerProtocolViolatedError(ExternalRoutineInvocationError): - sqlstate = '39P03' + sqlstate: typing.ClassVar[str] = '39P03' class SavepointError(_base.PostgresError): - sqlstate = '3B000' + sqlstate: typing.ClassVar[str] = '3B000' class InvalidSavepointSpecificationError(SavepointError): - sqlstate = '3B001' + sqlstate: typing.ClassVar[str] = '3B001' class InvalidCatalogNameError(_base.PostgresError): - sqlstate = '3D000' + sqlstate: typing.ClassVar[str] = '3D000' class InvalidSchemaNameError(_base.PostgresError): - sqlstate = '3F000' + sqlstate: typing.ClassVar[str] = '3F000' class TransactionRollbackError(_base.PostgresError): - sqlstate = '40000' + sqlstate: typing.ClassVar[str] = '40000' class TransactionIntegrityConstraintViolationError(TransactionRollbackError): - sqlstate = '40002' + sqlstate: typing.ClassVar[str] = '40002' class SerializationError(TransactionRollbackError): - sqlstate = '40001' + sqlstate: typing.ClassVar[str] = '40001' class StatementCompletionUnknownError(TransactionRollbackError): - sqlstate = '40003' + sqlstate: typing.ClassVar[str] = '40003' class DeadlockDetectedError(TransactionRollbackError): - sqlstate = '40P01' + sqlstate: typing.ClassVar[str] = '40P01' class SyntaxOrAccessError(_base.PostgresError): - sqlstate = '42000' + sqlstate: typing.ClassVar[str] = '42000' class PostgresSyntaxError(SyntaxOrAccessError): - sqlstate = '42601' + sqlstate: typing.ClassVar[str] = '42601' class InsufficientPrivilegeError(SyntaxOrAccessError): - sqlstate = '42501' + sqlstate: typing.ClassVar[str] = '42501' class CannotCoerceError(SyntaxOrAccessError): - sqlstate = '42846' + sqlstate: typing.ClassVar[str] = '42846' class GroupingError(SyntaxOrAccessError): - sqlstate = '42803' + sqlstate: typing.ClassVar[str] = '42803' class WindowingError(SyntaxOrAccessError): - sqlstate = '42P20' + sqlstate: typing.ClassVar[str] = '42P20' class InvalidRecursionError(SyntaxOrAccessError): - sqlstate = '42P19' + sqlstate: typing.ClassVar[str] = '42P19' class InvalidForeignKeyError(SyntaxOrAccessError): - sqlstate = '42830' + sqlstate: typing.ClassVar[str] = '42830' class InvalidNameError(SyntaxOrAccessError): - sqlstate = '42602' + sqlstate: typing.ClassVar[str] = '42602' class NameTooLongError(SyntaxOrAccessError): - sqlstate = '42622' + sqlstate: typing.ClassVar[str] = '42622' class ReservedNameError(SyntaxOrAccessError): - sqlstate = '42939' + sqlstate: typing.ClassVar[str] = '42939' class DatatypeMismatchError(SyntaxOrAccessError): - sqlstate = '42804' + sqlstate: typing.ClassVar[str] = '42804' class IndeterminateDatatypeError(SyntaxOrAccessError): - sqlstate = '42P18' + sqlstate: typing.ClassVar[str] = '42P18' class CollationMismatchError(SyntaxOrAccessError): - sqlstate = '42P21' + sqlstate: typing.ClassVar[str] = '42P21' class IndeterminateCollationError(SyntaxOrAccessError): - sqlstate = '42P22' + sqlstate: typing.ClassVar[str] = '42P22' class WrongObjectTypeError(SyntaxOrAccessError): - sqlstate = '42809' + sqlstate: typing.ClassVar[str] = '42809' class GeneratedAlwaysError(SyntaxOrAccessError): - sqlstate = '428C9' + sqlstate: typing.ClassVar[str] = '428C9' class UndefinedColumnError(SyntaxOrAccessError): - sqlstate = '42703' + sqlstate: typing.ClassVar[str] = '42703' class UndefinedFunctionError(SyntaxOrAccessError): - sqlstate = '42883' + sqlstate: typing.ClassVar[str] = '42883' class UndefinedTableError(SyntaxOrAccessError): - sqlstate = '42P01' + sqlstate: typing.ClassVar[str] = '42P01' class UndefinedParameterError(SyntaxOrAccessError): - sqlstate = '42P02' + sqlstate: typing.ClassVar[str] = '42P02' class UndefinedObjectError(SyntaxOrAccessError): - sqlstate = '42704' + sqlstate: typing.ClassVar[str] = '42704' class DuplicateColumnError(SyntaxOrAccessError): - sqlstate = '42701' + sqlstate: typing.ClassVar[str] = '42701' class DuplicateCursorError(SyntaxOrAccessError): - sqlstate = '42P03' + sqlstate: typing.ClassVar[str] = '42P03' class DuplicateDatabaseError(SyntaxOrAccessError): - sqlstate = '42P04' + sqlstate: typing.ClassVar[str] = '42P04' class DuplicateFunctionError(SyntaxOrAccessError): - sqlstate = '42723' + sqlstate: typing.ClassVar[str] = '42723' class DuplicatePreparedStatementError(SyntaxOrAccessError): - sqlstate = '42P05' + sqlstate: typing.ClassVar[str] = '42P05' class DuplicateSchemaError(SyntaxOrAccessError): - sqlstate = '42P06' + sqlstate: typing.ClassVar[str] = '42P06' class DuplicateTableError(SyntaxOrAccessError): - sqlstate = '42P07' + sqlstate: typing.ClassVar[str] = '42P07' class DuplicateAliasError(SyntaxOrAccessError): - sqlstate = '42712' + sqlstate: typing.ClassVar[str] = '42712' class DuplicateObjectError(SyntaxOrAccessError): - sqlstate = '42710' + sqlstate: typing.ClassVar[str] = '42710' class AmbiguousColumnError(SyntaxOrAccessError): - sqlstate = '42702' + sqlstate: typing.ClassVar[str] = '42702' class AmbiguousFunctionError(SyntaxOrAccessError): - sqlstate = '42725' + sqlstate: typing.ClassVar[str] = '42725' class AmbiguousParameterError(SyntaxOrAccessError): - sqlstate = '42P08' + sqlstate: typing.ClassVar[str] = '42P08' class AmbiguousAliasError(SyntaxOrAccessError): - sqlstate = '42P09' + sqlstate: typing.ClassVar[str] = '42P09' class InvalidColumnReferenceError(SyntaxOrAccessError): - sqlstate = '42P10' + sqlstate: typing.ClassVar[str] = '42P10' class InvalidColumnDefinitionError(SyntaxOrAccessError): - sqlstate = '42611' + sqlstate: typing.ClassVar[str] = '42611' class InvalidCursorDefinitionError(SyntaxOrAccessError): - sqlstate = '42P11' + sqlstate: typing.ClassVar[str] = '42P11' class InvalidDatabaseDefinitionError(SyntaxOrAccessError): - sqlstate = '42P12' + sqlstate: typing.ClassVar[str] = '42P12' class InvalidFunctionDefinitionError(SyntaxOrAccessError): - sqlstate = '42P13' + sqlstate: typing.ClassVar[str] = '42P13' class InvalidPreparedStatementDefinitionError(SyntaxOrAccessError): - sqlstate = '42P14' + sqlstate: typing.ClassVar[str] = '42P14' class InvalidSchemaDefinitionError(SyntaxOrAccessError): - sqlstate = '42P15' + sqlstate: typing.ClassVar[str] = '42P15' class InvalidTableDefinitionError(SyntaxOrAccessError): - sqlstate = '42P16' + sqlstate: typing.ClassVar[str] = '42P16' class InvalidObjectDefinitionError(SyntaxOrAccessError): - sqlstate = '42P17' + sqlstate: typing.ClassVar[str] = '42P17' class WithCheckOptionViolationError(_base.PostgresError): - sqlstate = '44000' + sqlstate: typing.ClassVar[str] = '44000' class InsufficientResourcesError(_base.PostgresError): - sqlstate = '53000' + sqlstate: typing.ClassVar[str] = '53000' class DiskFullError(InsufficientResourcesError): - sqlstate = '53100' + sqlstate: typing.ClassVar[str] = '53100' class OutOfMemoryError(InsufficientResourcesError): - sqlstate = '53200' + sqlstate: typing.ClassVar[str] = '53200' class TooManyConnectionsError(InsufficientResourcesError): - sqlstate = '53300' + sqlstate: typing.ClassVar[str] = '53300' class ConfigurationLimitExceededError(InsufficientResourcesError): - sqlstate = '53400' + sqlstate: typing.ClassVar[str] = '53400' class ProgramLimitExceededError(_base.PostgresError): - sqlstate = '54000' + sqlstate: typing.ClassVar[str] = '54000' class StatementTooComplexError(ProgramLimitExceededError): - sqlstate = '54001' + sqlstate: typing.ClassVar[str] = '54001' class TooManyColumnsError(ProgramLimitExceededError): - sqlstate = '54011' + sqlstate: typing.ClassVar[str] = '54011' class TooManyArgumentsError(ProgramLimitExceededError): - sqlstate = '54023' + sqlstate: typing.ClassVar[str] = '54023' class ObjectNotInPrerequisiteStateError(_base.PostgresError): - sqlstate = '55000' + sqlstate: typing.ClassVar[str] = '55000' class ObjectInUseError(ObjectNotInPrerequisiteStateError): - sqlstate = '55006' + sqlstate: typing.ClassVar[str] = '55006' class CantChangeRuntimeParamError(ObjectNotInPrerequisiteStateError): - sqlstate = '55P02' + sqlstate: typing.ClassVar[str] = '55P02' class LockNotAvailableError(ObjectNotInPrerequisiteStateError): - sqlstate = '55P03' + sqlstate: typing.ClassVar[str] = '55P03' class UnsafeNewEnumValueUsageError(ObjectNotInPrerequisiteStateError): - sqlstate = '55P04' + sqlstate: typing.ClassVar[str] = '55P04' class OperatorInterventionError(_base.PostgresError): - sqlstate = '57000' + sqlstate: typing.ClassVar[str] = '57000' class QueryCanceledError(OperatorInterventionError): - sqlstate = '57014' + sqlstate: typing.ClassVar[str] = '57014' class AdminShutdownError(OperatorInterventionError): - sqlstate = '57P01' + sqlstate: typing.ClassVar[str] = '57P01' class CrashShutdownError(OperatorInterventionError): - sqlstate = '57P02' + sqlstate: typing.ClassVar[str] = '57P02' class CannotConnectNowError(OperatorInterventionError): - sqlstate = '57P03' + sqlstate: typing.ClassVar[str] = '57P03' class DatabaseDroppedError(OperatorInterventionError): - sqlstate = '57P04' + sqlstate: typing.ClassVar[str] = '57P04' class IdleSessionTimeoutError(OperatorInterventionError): - sqlstate = '57P05' + sqlstate: typing.ClassVar[str] = '57P05' class PostgresSystemError(_base.PostgresError): - sqlstate = '58000' + sqlstate: typing.ClassVar[str] = '58000' class PostgresIOError(PostgresSystemError): - sqlstate = '58030' + sqlstate: typing.ClassVar[str] = '58030' class UndefinedFileError(PostgresSystemError): - sqlstate = '58P01' + sqlstate: typing.ClassVar[str] = '58P01' class DuplicateFileError(PostgresSystemError): - sqlstate = '58P02' + sqlstate: typing.ClassVar[str] = '58P02' class SnapshotTooOldError(_base.PostgresError): - sqlstate = '72000' + sqlstate: typing.ClassVar[str] = '72000' class ConfigFileError(_base.PostgresError): - sqlstate = 'F0000' + sqlstate: typing.ClassVar[str] = 'F0000' class LockFileExistsError(ConfigFileError): - sqlstate = 'F0001' + sqlstate: typing.ClassVar[str] = 'F0001' class FDWError(_base.PostgresError): - sqlstate = 'HV000' + sqlstate: typing.ClassVar[str] = 'HV000' class FDWColumnNameNotFoundError(FDWError): - sqlstate = 'HV005' + sqlstate: typing.ClassVar[str] = 'HV005' class FDWDynamicParameterValueNeededError(FDWError): - sqlstate = 'HV002' + sqlstate: typing.ClassVar[str] = 'HV002' class FDWFunctionSequenceError(FDWError): - sqlstate = 'HV010' + sqlstate: typing.ClassVar[str] = 'HV010' class FDWInconsistentDescriptorInformationError(FDWError): - sqlstate = 'HV021' + sqlstate: typing.ClassVar[str] = 'HV021' class FDWInvalidAttributeValueError(FDWError): - sqlstate = 'HV024' + sqlstate: typing.ClassVar[str] = 'HV024' class FDWInvalidColumnNameError(FDWError): - sqlstate = 'HV007' + sqlstate: typing.ClassVar[str] = 'HV007' class FDWInvalidColumnNumberError(FDWError): - sqlstate = 'HV008' + sqlstate: typing.ClassVar[str] = 'HV008' class FDWInvalidDataTypeError(FDWError): - sqlstate = 'HV004' + sqlstate: typing.ClassVar[str] = 'HV004' class FDWInvalidDataTypeDescriptorsError(FDWError): - sqlstate = 'HV006' + sqlstate: typing.ClassVar[str] = 'HV006' class FDWInvalidDescriptorFieldIdentifierError(FDWError): - sqlstate = 'HV091' + sqlstate: typing.ClassVar[str] = 'HV091' class FDWInvalidHandleError(FDWError): - sqlstate = 'HV00B' + sqlstate: typing.ClassVar[str] = 'HV00B' class FDWInvalidOptionIndexError(FDWError): - sqlstate = 'HV00C' + sqlstate: typing.ClassVar[str] = 'HV00C' class FDWInvalidOptionNameError(FDWError): - sqlstate = 'HV00D' + sqlstate: typing.ClassVar[str] = 'HV00D' class FDWInvalidStringLengthOrBufferLengthError(FDWError): - sqlstate = 'HV090' + sqlstate: typing.ClassVar[str] = 'HV090' class FDWInvalidStringFormatError(FDWError): - sqlstate = 'HV00A' + sqlstate: typing.ClassVar[str] = 'HV00A' class FDWInvalidUseOfNullPointerError(FDWError): - sqlstate = 'HV009' + sqlstate: typing.ClassVar[str] = 'HV009' class FDWTooManyHandlesError(FDWError): - sqlstate = 'HV014' + sqlstate: typing.ClassVar[str] = 'HV014' class FDWOutOfMemoryError(FDWError): - sqlstate = 'HV001' + sqlstate: typing.ClassVar[str] = 'HV001' class FDWNoSchemasError(FDWError): - sqlstate = 'HV00P' + sqlstate: typing.ClassVar[str] = 'HV00P' class FDWOptionNameNotFoundError(FDWError): - sqlstate = 'HV00J' + sqlstate: typing.ClassVar[str] = 'HV00J' class FDWReplyHandleError(FDWError): - sqlstate = 'HV00K' + sqlstate: typing.ClassVar[str] = 'HV00K' class FDWSchemaNotFoundError(FDWError): - sqlstate = 'HV00Q' + sqlstate: typing.ClassVar[str] = 'HV00Q' class FDWTableNotFoundError(FDWError): - sqlstate = 'HV00R' + sqlstate: typing.ClassVar[str] = 'HV00R' class FDWUnableToCreateExecutionError(FDWError): - sqlstate = 'HV00L' + sqlstate: typing.ClassVar[str] = 'HV00L' class FDWUnableToCreateReplyError(FDWError): - sqlstate = 'HV00M' + sqlstate: typing.ClassVar[str] = 'HV00M' class FDWUnableToEstablishConnectionError(FDWError): - sqlstate = 'HV00N' + sqlstate: typing.ClassVar[str] = 'HV00N' class PLPGSQLError(_base.PostgresError): - sqlstate = 'P0000' + sqlstate: typing.ClassVar[str] = 'P0000' class RaiseError(PLPGSQLError): - sqlstate = 'P0001' + sqlstate: typing.ClassVar[str] = 'P0001' class NoDataFoundError(PLPGSQLError): - sqlstate = 'P0002' + sqlstate: typing.ClassVar[str] = 'P0002' class TooManyRowsError(PLPGSQLError): - sqlstate = 'P0003' + sqlstate: typing.ClassVar[str] = 'P0003' class AssertError(PLPGSQLError): - sqlstate = 'P0004' + sqlstate: typing.ClassVar[str] = 'P0004' class InternalServerError(_base.PostgresError): - sqlstate = 'XX000' + sqlstate: typing.ClassVar[str] = 'XX000' class DataCorruptedError(InternalServerError): - sqlstate = 'XX001' + sqlstate: typing.ClassVar[str] = 'XX001' class IndexCorruptedError(InternalServerError): - sqlstate = 'XX002' + sqlstate: typing.ClassVar[str] = 'XX002' -__all__ = ( +__all__ = [ 'ActiveSQLTransactionError', 'AdminShutdownError', 'AmbiguousAliasError', 'AmbiguousColumnError', 'AmbiguousFunctionError', 'AmbiguousParameterError', @@ -1193,6 +1196,6 @@ class IndexCorruptedError(InternalServerError): 'UnterminatedCStringError', 'UntranslatableCharacterError', 'WindowingError', 'WithCheckOptionViolationError', 'WrongObjectTypeError', 'ZeroLengthCharacterStringError' -) +] __all__ += _base.__all__ diff --git a/asyncpg/exceptions/_base.py b/asyncpg/exceptions/_base.py index 00e9699a..5763e180 100644 --- a/asyncpg/exceptions/_base.py +++ b/asyncpg/exceptions/_base.py @@ -4,169 +4,36 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +from __future__ import annotations import asyncpg -import sys -import textwrap +import typing +if typing.TYPE_CHECKING: + import sys -__all__ = ('PostgresError', 'FatalPostgresError', 'UnknownPostgresError', + if sys.version_info < (3, 11): + from typing_extensions import Self + else: + from typing import Self + +from ._postgres_message import PostgresMessage as PostgresMessage + +__all__ = ['PostgresError', 'FatalPostgresError', 'UnknownPostgresError', 'InterfaceError', 'InterfaceWarning', 'PostgresLogMessage', 'ClientConfigurationError', 'InternalClientError', 'OutdatedSchemaCacheError', 'ProtocolError', 'UnsupportedClientFeatureError', 'TargetServerAttributeNotMatched', - 'UnsupportedServerFeatureError') - - -def _is_asyncpg_class(cls): - modname = cls.__module__ - return modname == 'asyncpg' or modname.startswith('asyncpg.') - - -class PostgresMessageMeta(type): - - _message_map = {} - _field_map = { - 'S': 'severity', - 'V': 'severity_en', - 'C': 'sqlstate', - 'M': 'message', - 'D': 'detail', - 'H': 'hint', - 'P': 'position', - 'p': 'internal_position', - 'q': 'internal_query', - 'W': 'context', - 's': 'schema_name', - 't': 'table_name', - 'c': 'column_name', - 'd': 'data_type_name', - 'n': 'constraint_name', - 'F': 'server_source_filename', - 'L': 'server_source_line', - 'R': 'server_source_function' - } - - def __new__(mcls, name, bases, dct): - cls = super().__new__(mcls, name, bases, dct) - if cls.__module__ == mcls.__module__ and name == 'PostgresMessage': - for f in mcls._field_map.values(): - setattr(cls, f, None) - - if _is_asyncpg_class(cls): - mod = sys.modules[cls.__module__] - if hasattr(mod, name): - raise RuntimeError('exception class redefinition: {}'.format( - name)) - - code = dct.get('sqlstate') - if code is not None: - existing = mcls._message_map.get(code) - if existing is not None: - raise TypeError('{} has duplicate SQLSTATE code, which is' - 'already defined by {}'.format( - name, existing.__name__)) - mcls._message_map[code] = cls - - return cls - - @classmethod - def get_message_class_for_sqlstate(mcls, code): - return mcls._message_map.get(code, UnknownPostgresError) + 'UnsupportedServerFeatureError'] - -class PostgresMessage(metaclass=PostgresMessageMeta): - - @classmethod - def _get_error_class(cls, fields): - sqlstate = fields.get('C') - return type(cls).get_message_class_for_sqlstate(sqlstate) - - @classmethod - def _get_error_dict(cls, fields, query): - dct = { - 'query': query - } - - field_map = type(cls)._field_map - for k, v in fields.items(): - field = field_map.get(k) - if field: - dct[field] = v - - return dct - - @classmethod - def _make_constructor(cls, fields, query=None): - dct = cls._get_error_dict(fields, query) - - exccls = cls._get_error_class(fields) - message = dct.get('message', '') - - # PostgreSQL will raise an exception when it detects - # that the result type of the query has changed from - # when the statement was prepared. - # - # The original error is somewhat cryptic and unspecific, - # so we raise a custom subclass that is easier to handle - # and identify. - # - # Note that we specifically do not rely on the error - # message, as it is localizable. - is_icse = ( - exccls.__name__ == 'FeatureNotSupportedError' and - _is_asyncpg_class(exccls) and - dct.get('server_source_function') == 'RevalidateCachedQuery' - ) - - if is_icse: - exceptions = sys.modules[exccls.__module__] - exccls = exceptions.InvalidCachedStatementError - message = ('cached statement plan is invalid due to a database ' - 'schema or configuration change') - - is_prepared_stmt_error = ( - exccls.__name__ in ('DuplicatePreparedStatementError', - 'InvalidSQLStatementNameError') and - _is_asyncpg_class(exccls) - ) - - if is_prepared_stmt_error: - hint = dct.get('hint', '') - hint += textwrap.dedent("""\ - - NOTE: pgbouncer with pool_mode set to "transaction" or - "statement" does not support prepared statements properly. - You have two options: - - * if you are using pgbouncer for connection pooling to a - single server, switch to the connection pool functionality - provided by asyncpg, it is a much better option for this - purpose; - - * if you have no option of avoiding the use of pgbouncer, - then you can set statement_cache_size to 0 when creating - the asyncpg connection object. - """) - - dct['hint'] = hint - - return exccls, message, dct - - def as_dict(self): - dct = {} - for f in type(self)._field_map.values(): - val = getattr(self, f) - if val is not None: - dct[f] = val - return dct +_PM = typing.TypeVar('_PM', bound='PostgresMessage') class PostgresError(PostgresMessage, Exception): """Base class for all Postgres errors.""" - def __str__(self): - msg = self.args[0] + def __str__(self) -> str: + msg: str = self.args[0] if self.detail: msg += '\nDETAIL: {}'.format(self.detail) if self.hint: @@ -175,7 +42,7 @@ def __str__(self): return msg @classmethod - def new(cls, fields, query=None): + def new(cls, fields: dict[str, str], query: str | None = None) -> Self: exccls, message, dct = cls._make_constructor(fields, query) ex = exccls(message) ex.__dict__.update(dct) @@ -191,11 +58,20 @@ class UnknownPostgresError(FatalPostgresError): class InterfaceMessage: - def __init__(self, *, detail=None, hint=None): + args: tuple[str, ...] + detail: str | None + hint: str | None + + def __init__( + self, + *, + detail: str | None = None, + hint: str | None = None, + ) -> None: self.detail = detail self.hint = hint - def __str__(self): + def __str__(self) -> str: msg = self.args[0] if self.detail: msg += '\nDETAIL: {}'.format(self.detail) @@ -208,11 +84,17 @@ def __str__(self): class InterfaceError(InterfaceMessage, Exception): """An error caused by improper use of asyncpg API.""" - def __init__(self, msg, *, detail=None, hint=None): + def __init__( + self, + msg: str, + *, + detail: str | None = None, + hint: str | None = None, + ) -> None: InterfaceMessage.__init__(self, detail=detail, hint=hint) Exception.__init__(self, msg) - def with_msg(self, msg): + def with_msg(self, msg: str) -> Self: return type(self)( msg, detail=self.detail, @@ -241,7 +123,13 @@ class UnsupportedServerFeatureError(InterfaceError): class InterfaceWarning(InterfaceMessage, UserWarning): """A warning caused by an improper use of asyncpg API.""" - def __init__(self, msg, *, detail=None, hint=None): + def __init__( + self, + msg: str, + *, + detail: str | None = None, + hint: str | None = None, + ) -> None: InterfaceMessage.__init__(self, detail=detail, hint=hint) UserWarning.__init__(self, msg) @@ -261,7 +149,18 @@ class TargetServerAttributeNotMatched(InternalClientError): class OutdatedSchemaCacheError(InternalClientError): """A value decoding error caused by a schema change before row fetching.""" - def __init__(self, msg, *, schema=None, data_type=None, position=None): + schema_name: str | None + data_type_name: str | None + position: str | None + + def __init__( + self, + msg: str, + *, + schema: str | None = None, + data_type: str | None = None, + position: str | None = None, + ) -> None: super().__init__(msg) self.schema_name = schema self.data_type_name = data_type @@ -271,15 +170,18 @@ def __init__(self, msg, *, schema=None, data_type=None, position=None): class PostgresLogMessage(PostgresMessage): """A base class for non-error server messages.""" - def __str__(self): + def __str__(self) -> str: return '{}: {}'.format(type(self).__name__, self.message) - def __setattr__(self, name, val): + def __setattr__(self, name: str, val: object) -> None: raise TypeError('instances of {} are immutable'.format( type(self).__name__)) @classmethod - def new(cls, fields, query=None): + def new( + cls: type[_PM], fields: dict[str, str], query: str | None = None + ) -> PostgresMessage: + exccls: type[PostgresMessage] exccls, message_text, dct = cls._make_constructor(fields, query) if exccls is UnknownPostgresError: @@ -291,7 +193,7 @@ def new(cls, fields, query=None): exccls = asyncpg.PostgresWarning if issubclass(exccls, (BaseException, Warning)): - msg = exccls(message_text) + msg: PostgresMessage = exccls(message_text) else: msg = exccls() diff --git a/asyncpg/exceptions/_postgres_message.py b/asyncpg/exceptions/_postgres_message.py new file mode 100644 index 00000000..c281d7dd --- /dev/null +++ b/asyncpg/exceptions/_postgres_message.py @@ -0,0 +1,155 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + +from __future__ import annotations + +import asyncpg +import sys +import textwrap + + +def _is_asyncpg_class(cls): + modname = cls.__module__ + return modname == 'asyncpg' or modname.startswith('asyncpg.') + + +class PostgresMessageMeta(type): + + _message_map = {} + _field_map = { + 'S': 'severity', + 'V': 'severity_en', + 'C': 'sqlstate', + 'M': 'message', + 'D': 'detail', + 'H': 'hint', + 'P': 'position', + 'p': 'internal_position', + 'q': 'internal_query', + 'W': 'context', + 's': 'schema_name', + 't': 'table_name', + 'c': 'column_name', + 'd': 'data_type_name', + 'n': 'constraint_name', + 'F': 'server_source_filename', + 'L': 'server_source_line', + 'R': 'server_source_function' + } + + def __new__(mcls, name, bases, dct): + cls = super().__new__(mcls, name, bases, dct) + if cls.__module__ == mcls.__module__ and name == 'PostgresMessage': + for f in mcls._field_map.values(): + setattr(cls, f, None) + + if _is_asyncpg_class(cls): + mod = sys.modules[cls.__module__] + if hasattr(mod, name): + raise RuntimeError('exception class redefinition: {}'.format( + name)) + + code = dct.get('sqlstate') + if code is not None: + existing = mcls._message_map.get(code) + if existing is not None: + raise TypeError('{} has duplicate SQLSTATE code, which is' + 'already defined by {}'.format( + name, existing.__name__)) + mcls._message_map[code] = cls + + return cls + + @classmethod + def get_message_class_for_sqlstate(mcls, code): + return mcls._message_map.get(code, asyncpg.UnknownPostgresError) + + +class PostgresMessage(metaclass=PostgresMessageMeta): + + @classmethod + def _get_error_class(cls, fields): + sqlstate = fields.get('C') + return type(cls).get_message_class_for_sqlstate(sqlstate) + + @classmethod + def _get_error_dict(cls, fields, query): + dct = { + 'query': query + } + + field_map = type(cls)._field_map + for k, v in fields.items(): + field = field_map.get(k) + if field: + dct[field] = v + + return dct + + @classmethod + def _make_constructor(cls, fields, query=None): + dct = cls._get_error_dict(fields, query) + + exccls = cls._get_error_class(fields) + message = dct.get('message', '') + + # PostgreSQL will raise an exception when it detects + # that the result type of the query has changed from + # when the statement was prepared. + # + # The original error is somewhat cryptic and unspecific, + # so we raise a custom subclass that is easier to handle + # and identify. + # + # Note that we specifically do not rely on the error + # message, as it is localizable. + is_icse = ( + exccls.__name__ == 'FeatureNotSupportedError' and + _is_asyncpg_class(exccls) and + dct.get('server_source_function') == 'RevalidateCachedQuery' + ) + + if is_icse: + exceptions = sys.modules[exccls.__module__] + exccls = exceptions.InvalidCachedStatementError + message = ('cached statement plan is invalid due to a database ' + 'schema or configuration change') + + is_prepared_stmt_error = ( + exccls.__name__ in ('DuplicatePreparedStatementError', + 'InvalidSQLStatementNameError') and + _is_asyncpg_class(exccls) + ) + + if is_prepared_stmt_error: + hint = dct.get('hint', '') + hint += textwrap.dedent("""\ + + NOTE: pgbouncer with pool_mode set to "transaction" or + "statement" does not support prepared statements properly. + You have two options: + + * if you are using pgbouncer for connection pooling to a + single server, switch to the connection pool functionality + provided by asyncpg, it is a much better option for this + purpose; + + * if you have no option of avoiding the use of pgbouncer, + then you can set statement_cache_size to 0 when creating + the asyncpg connection object. + """) + + dct['hint'] = hint + + return exccls, message, dct + + def as_dict(self): + dct = {} + for f in type(self)._field_map.values(): + val = getattr(self, f) + if val is not None: + dct[f] = val + return dct diff --git a/asyncpg/exceptions/_postgres_message.pyi b/asyncpg/exceptions/_postgres_message.pyi new file mode 100644 index 00000000..7d2d4e75 --- /dev/null +++ b/asyncpg/exceptions/_postgres_message.pyi @@ -0,0 +1,36 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + +import typing + +_PM = typing.TypeVar('_PM', bound=PostgresMessage) + +class PostgresMessageMeta(type): ... + +class PostgresMessage(metaclass=PostgresMessageMeta): + severity: str | None + severity_en: str | None + sqlstate: typing.ClassVar[str] + message: str + detail: str | None + hint: str | None + position: str | None + internal_position: str | None + internal_query: str | None + context: str | None + schema_name: str | None + table_name: str | None + column_name: str | None + data_type_name: str | None + constraint_name: str | None + server_source_filename: str | None + server_source_line: str | None + server_source_function: str | None + @classmethod + def _make_constructor( + cls: type[_PM], fields: dict[str, str], query: str | None = ... + ) -> tuple[type[_PM], str, dict[str, str]]: ... + def as_dict(self) -> dict[str, str]: ... diff --git a/asyncpg/introspection.py b/asyncpg/introspection.py index 6c2caf03..95ce0f0a 100644 --- a/asyncpg/introspection.py +++ b/asyncpg/introspection.py @@ -4,8 +4,15 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +from __future__ import annotations -_TYPEINFO_13 = '''\ +import typing + +if typing.TYPE_CHECKING: + from . import protocol + + +_TYPEINFO_13: typing.Final = '''\ ( SELECT t.oid AS oid, @@ -82,7 +89,7 @@ ''' -INTRO_LOOKUP_TYPES_13 = '''\ +INTRO_LOOKUP_TYPES_13: typing.Final = '''\ WITH RECURSIVE typeinfo_tree( oid, ns, name, kind, basetype, elemtype, elemdelim, range_subtype, attrtypoids, attrnames, depth) @@ -124,7 +131,7 @@ '''.format(typeinfo=_TYPEINFO_13) -_TYPEINFO = '''\ +_TYPEINFO: typing.Final = '''\ ( SELECT t.oid AS oid, @@ -206,7 +213,7 @@ ''' -INTRO_LOOKUP_TYPES = '''\ +INTRO_LOOKUP_TYPES: typing.Final = '''\ WITH RECURSIVE typeinfo_tree( oid, ns, name, kind, basetype, elemtype, elemdelim, range_subtype, attrtypoids, attrnames, depth) @@ -248,7 +255,7 @@ '''.format(typeinfo=_TYPEINFO) -TYPE_BY_NAME = '''\ +TYPE_BY_NAME: typing.Final = '''\ SELECT t.oid, t.typelem AS elemtype, @@ -274,19 +281,19 @@ # 'b' for a base type, 'd' for a domain, 'e' for enum. -SCALAR_TYPE_KINDS = (b'b', b'd', b'e') +SCALAR_TYPE_KINDS: typing.Final = (b'b', b'd', b'e') -def is_scalar_type(typeinfo) -> bool: +def is_scalar_type(typeinfo: protocol.Record) -> bool: return ( typeinfo['kind'] in SCALAR_TYPE_KINDS and not typeinfo['elemtype'] ) -def is_domain_type(typeinfo) -> bool: - return typeinfo['kind'] == b'd' +def is_domain_type(typeinfo: protocol.Record) -> bool: + return typing.cast(bytes, typeinfo['kind']) == b'd' -def is_composite_type(typeinfo) -> bool: - return typeinfo['kind'] == b'c' +def is_composite_type(typeinfo: protocol.Record) -> bool: + return typing.cast(bytes, typeinfo['kind']) == b'c' diff --git a/asyncpg/pgproto b/asyncpg/pgproto index 1c3cad14..dbb69452 160000 --- a/asyncpg/pgproto +++ b/asyncpg/pgproto @@ -1 +1 @@ -Subproject commit 1c3cad14d53c8f3088106f4eab8f612b7293569b +Subproject commit dbb69452baaac89ae46cbae0fb6b4a267083d16f diff --git a/asyncpg/pool.py b/asyncpg/pool.py index 06e698df..1544a6c9 100644 --- a/asyncpg/pool.py +++ b/asyncpg/pool.py @@ -4,94 +4,53 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +from __future__ import annotations import asyncio -import functools -import inspect import logging import time +import typing import warnings from . import compat from . import connection from . import exceptions +from . import pool_connection_proxy from . import protocol +if typing.TYPE_CHECKING: + import sys -logger = logging.getLogger(__name__) + if sys.version_info < (3, 11): + from typing_extensions import Self + else: + from typing import Self + from . import connect_utils -class PoolConnectionProxyMeta(type): +_ConnectionT = typing.TypeVar( + '_ConnectionT', bound=connection.Connection[typing.Any] +) +_RecordT = typing.TypeVar('_RecordT', bound=protocol.Record) +_OtherRecordT = typing.TypeVar('_OtherRecordT', bound=protocol.Record) - def __new__(mcls, name, bases, dct, *, wrap=False): - if wrap: - for attrname in dir(connection.Connection): - if attrname.startswith('_') or attrname in dct: - continue +_logger = logging.getLogger(__name__) - meth = getattr(connection.Connection, attrname) - if not inspect.isfunction(meth): - continue - wrapper = mcls._wrap_connection_method(attrname) - wrapper = functools.update_wrapper(wrapper, meth) - dct[attrname] = wrapper - - if '__doc__' not in dct: - dct['__doc__'] = connection.Connection.__doc__ - - return super().__new__(mcls, name, bases, dct) - - @staticmethod - def _wrap_connection_method(meth_name): - def call_con_method(self, *args, **kwargs): - # This method will be owned by PoolConnectionProxy class. - if self._con is None: - raise exceptions.InterfaceError( - 'cannot call Connection.{}(): ' - 'connection has been released back to the pool'.format( - meth_name)) - - meth = getattr(self._con.__class__, meth_name) - return meth(self._con, *args, **kwargs) - - return call_con_method - - -class PoolConnectionProxy(connection._ConnectionProxy, - metaclass=PoolConnectionProxyMeta, - wrap=True): - - __slots__ = ('_con', '_holder') - - def __init__(self, holder: 'PoolConnectionHolder', - con: connection.Connection): - self._con = con - self._holder = holder - con._set_proxy(self) - - def __getattr__(self, attr): - # Proxy all unresolved attributes to the wrapped Connection object. - return getattr(self._con, attr) - - def _detach(self) -> connection.Connection: - if self._con is None: - return +class _SetupCallback(typing.Protocol[_RecordT]): + async def __call__( + self, + __proxy: pool_connection_proxy.PoolConnectionProxy[_RecordT] + ) -> None: + ... - con, self._con = self._con, None - con._set_proxy(None) - return con - def __repr__(self): - if self._con is None: - return '<{classname} [released] {id:#x}>'.format( - classname=self.__class__.__name__, id=id(self)) - else: - return '<{classname} {con!r} {id:#x}>'.format( - classname=self.__class__.__name__, con=self._con, id=id(self)) +class _InitCallback(typing.Protocol[_RecordT]): + async def __call__(self, __con: connection.Connection[_RecordT]) -> None: + ... -class PoolConnectionHolder: +class PoolConnectionHolder(typing.Generic[_RecordT]): __slots__ = ('_con', '_pool', '_loop', '_proxy', '_max_queries', '_setup', @@ -99,7 +58,25 @@ class PoolConnectionHolder: '_inactive_callback', '_timeout', '_generation') - def __init__(self, pool, *, max_queries, setup, max_inactive_time): + _con: connection.Connection[_RecordT] | None + _pool: Pool[_RecordT] + _proxy: pool_connection_proxy.PoolConnectionProxy[_RecordT] | None + _max_queries: int + _setup: _SetupCallback[_RecordT] | None + _max_inactive_time: float + _in_use: asyncio.Future[None] | None + _inactive_callback: asyncio.TimerHandle | None + _timeout: float | None + _generation: int | None + + def __init__( + self, + pool: Pool[_RecordT], + *, + max_queries: int, + setup: _SetupCallback[_RecordT] | None, + max_inactive_time: float + ) -> None: self._pool = pool self._con = None @@ -109,17 +86,17 @@ def __init__(self, pool, *, max_queries, setup, max_inactive_time): self._max_inactive_time = max_inactive_time self._setup = setup self._inactive_callback = None - self._in_use = None # type: asyncio.Future + self._in_use = None self._timeout = None self._generation = None - def is_connected(self): + def is_connected(self) -> bool: return self._con is not None and not self._con.is_closed() - def is_idle(self): + def is_idle(self) -> bool: return not self._in_use - async def connect(self): + async def connect(self) -> None: if self._con is not None: raise exceptions.InternalClientError( 'PoolConnectionHolder.connect() called while another ' @@ -130,7 +107,9 @@ async def connect(self): self._maybe_cancel_inactive_callback() self._setup_inactive_callback() - async def acquire(self) -> PoolConnectionProxy: + async def acquire( + self + ) -> pool_connection_proxy.PoolConnectionProxy[_RecordT]: if self._con is None or self._con.is_closed(): self._con = None await self.connect() @@ -142,9 +121,14 @@ async def acquire(self) -> PoolConnectionProxy: self._con = None await self.connect() + if typing.TYPE_CHECKING: + assert self._con is not None + self._maybe_cancel_inactive_callback() - self._proxy = proxy = PoolConnectionProxy(self, self._con) + self._proxy = proxy = pool_connection_proxy.PoolConnectionProxy( + self, self._con + ) if self._setup is not None: try: @@ -167,12 +151,15 @@ async def acquire(self) -> PoolConnectionProxy: return proxy - async def release(self, timeout): + async def release(self, timeout: float | None) -> None: if self._in_use is None: raise exceptions.InternalClientError( 'PoolConnectionHolder.release() called on ' 'a free connection holder') + if typing.TYPE_CHECKING: + assert self._con is not None + if self._con.is_closed(): # When closing, pool connections perform the necessary # cleanup, so we don't have to do anything else here. @@ -225,25 +212,25 @@ async def release(self, timeout): # Rearm the connection inactivity timer. self._setup_inactive_callback() - async def wait_until_released(self): + async def wait_until_released(self) -> None: if self._in_use is None: return else: await self._in_use - async def close(self): + async def close(self) -> None: if self._con is not None: # Connection.close() will call _release_on_close() to # finish holder cleanup. await self._con.close() - def terminate(self): + def terminate(self) -> None: if self._con is not None: # Connection.terminate() will call _release_on_close() to # finish holder cleanup. self._con.terminate() - def _setup_inactive_callback(self): + def _setup_inactive_callback(self) -> None: if self._inactive_callback is not None: raise exceptions.InternalClientError( 'pool connection inactivity timer already exists') @@ -252,12 +239,12 @@ def _setup_inactive_callback(self): self._inactive_callback = self._pool._loop.call_later( self._max_inactive_time, self._deactivate_inactive_connection) - def _maybe_cancel_inactive_callback(self): + def _maybe_cancel_inactive_callback(self) -> None: if self._inactive_callback is not None: self._inactive_callback.cancel() self._inactive_callback = None - def _deactivate_inactive_connection(self): + def _deactivate_inactive_connection(self) -> None: if self._in_use is not None: raise exceptions.InternalClientError( 'attempting to deactivate an acquired connection') @@ -271,12 +258,12 @@ def _deactivate_inactive_connection(self): # so terminate() above will not call the below. self._release_on_close() - def _release_on_close(self): + def _release_on_close(self) -> None: self._maybe_cancel_inactive_callback() self._release() self._con = None - def _release(self): + def _release(self) -> None: """Release this connection holder.""" if self._in_use is None: # The holder is not checked out. @@ -292,11 +279,14 @@ def _release(self): self._proxy._detach() self._proxy = None + if typing.TYPE_CHECKING: + assert self._pool._queue is not None + # Put ourselves back to the pool queue. self._pool._queue.put_nowait(self) -class Pool: +class Pool(typing.Generic[_RecordT]): """A connection pool. Connection pool can be used to manage a set of connections to the database. @@ -315,17 +305,42 @@ class Pool: '_setup', '_max_queries', '_max_inactive_connection_lifetime' ) - def __init__(self, *connect_args, - min_size, - max_size, - max_queries, - max_inactive_connection_lifetime, - setup, - init, - loop, - connection_class, - record_class, - **connect_kwargs): + _queue: asyncio.LifoQueue[PoolConnectionHolder[_RecordT]] | None + _loop: asyncio.AbstractEventLoop + _minsize: int + _maxsize: int + _init: _InitCallback[_RecordT] | None + _connect_args: tuple[str | None] | tuple[()] + _connect_kwargs: dict[str, object] + _working_addr: typing.Tuple[str, int] | str + _working_config: connect_utils._ClientConfiguration | None + _working_params: connect_utils._ConnectionParameters | None + _holders: list[PoolConnectionHolder[_RecordT]] + _initialized: bool + _initializing: bool + _closing: bool + _closed: bool + _connection_class: type[connection.Connection[_RecordT]] + _record_class: type[_RecordT] + _generation: int + _setup: _SetupCallback[_RecordT] | None + _max_queries: int + _max_inactive_connection_lifetime: float + + def __init__( + self, + *connect_args: str | None, + min_size: int, + max_size: int, + max_queries: int, + max_inactive_connection_lifetime: float, + setup: _SetupCallback[_RecordT] | None, + init: _InitCallback[_RecordT] | None, + loop: asyncio.AbstractEventLoop | None, + connection_class: type[_ConnectionT], + record_class: type[_RecordT], + **connect_kwargs: object + ): if len(connect_args) > 1: warnings.warn( @@ -382,7 +397,9 @@ def __init__(self, *connect_args, self._closed = False self._generation = 0 self._init = init - self._connect_args = connect_args + self._connect_args = ( + () if not len(connect_args) else (connect_args[0],) + ) self._connect_kwargs = connect_kwargs self._setup = setup @@ -390,9 +407,9 @@ def __init__(self, *connect_args, self._max_inactive_connection_lifetime = \ max_inactive_connection_lifetime - async def _async__init__(self): + async def _async__init__(self) -> Self | None: if self._initialized: - return + return None if self._initializing: raise exceptions.InterfaceError( 'pool is being initialized in another task') @@ -406,7 +423,7 @@ async def _async__init__(self): self._initializing = False self._initialized = True - async def _initialize(self): + async def _initialize(self) -> None: self._queue = asyncio.LifoQueue(maxsize=self._maxsize) for _ in range(self._maxsize): ch = PoolConnectionHolder( @@ -426,11 +443,11 @@ async def _initialize(self): # Connect the first connection holder in the queue so that # any connection issues are visible early. - first_ch = self._holders[-1] # type: PoolConnectionHolder + first_ch: PoolConnectionHolder[_RecordT] = self._holders[-1] await first_ch.connect() if self._minsize > 1: - connect_tasks = [] + connect_tasks: list[compat.Awaitable[None]] = [] for i, ch in enumerate(reversed(self._holders[:-1])): # `minsize - 1` because we already have first_ch if i >= self._minsize - 1: @@ -439,42 +456,44 @@ async def _initialize(self): await asyncio.gather(*connect_tasks) - def is_closing(self): + def is_closing(self) -> bool: """Return ``True`` if the pool is closing or is closed. .. versionadded:: 0.28.0 """ return self._closed or self._closing - def get_size(self): + def get_size(self) -> int: """Return the current number of connections in this pool. .. versionadded:: 0.25.0 """ return sum(h.is_connected() for h in self._holders) - def get_min_size(self): + def get_min_size(self) -> int: """Return the minimum number of connections in this pool. .. versionadded:: 0.25.0 """ return self._minsize - def get_max_size(self): + def get_max_size(self) -> int: """Return the maximum allowed number of connections in this pool. .. versionadded:: 0.25.0 """ return self._maxsize - def get_idle_size(self): + def get_idle_size(self) -> int: """Return the current number of idle connections in this pool. .. versionadded:: 0.25.0 """ return sum(h.is_connected() and h.is_idle() for h in self._holders) - def set_connect_args(self, dsn=None, **connect_kwargs): + def set_connect_args( + self, dsn: str | None = None, **connect_kwargs: object + ) -> None: r"""Set the new connection arguments for this pool. The new connection arguments will be used for all subsequent @@ -495,16 +514,16 @@ def set_connect_args(self, dsn=None, **connect_kwargs): .. versionadded:: 0.16.0 """ - self._connect_args = [dsn] + self._connect_args = (dsn,) self._connect_kwargs = connect_kwargs - async def _get_new_connection(self): - con = await connection.connect( + async def _get_new_connection(self) -> connection.Connection[_RecordT]: + con: connection.Connection[_RecordT] = await connection.connect( *self._connect_args, loop=self._loop, connection_class=self._connection_class, record_class=self._record_class, - **self._connect_kwargs, + **typing.cast(typing.Any, self._connect_kwargs), ) if self._init is not None: @@ -526,7 +545,9 @@ async def _get_new_connection(self): return con - async def execute(self, query: str, *args, timeout: float=None) -> str: + async def execute( + self, query: str, *args: object, timeout: float | None = None + ) -> str: """Execute an SQL command (or commands). Pool performs this operation using one of its connections. Other than @@ -538,7 +559,13 @@ async def execute(self, query: str, *args, timeout: float=None) -> str: async with self.acquire() as con: return await con.execute(query, *args, timeout=timeout) - async def executemany(self, command: str, args, *, timeout: float=None): + async def executemany( + self, + command: str, + args: compat.Iterable[compat.Sequence[object]], + *, + timeout: float | None = None, + ) -> None: """Execute an SQL *command* for each sequence of arguments in *args*. Pool performs this operation using one of its connections. Other than @@ -551,13 +578,43 @@ async def executemany(self, command: str, args, *, timeout: float=None): async with self.acquire() as con: return await con.executemany(command, args, timeout=timeout) + @typing.overload + async def fetch( + self, + query: str, + *args: object, + timeout: float | None = ..., + record_class: None = ..., + ) -> list[_RecordT]: + ... + + @typing.overload + async def fetch( + self, + query: str, + *args: object, + timeout: float | None = ..., + record_class: type[_OtherRecordT], + ) -> list[_OtherRecordT]: + ... + + @typing.overload + async def fetch( + self, + query: str, + *args: object, + timeout: float | None = ..., + record_class: type[_OtherRecordT] | None, + ) -> list[_RecordT] | list[_OtherRecordT]: + ... + async def fetch( self, - query, - *args, - timeout=None, - record_class=None - ) -> list: + query: str, + *args: object, + timeout: float | None = None, + record_class: type[_OtherRecordT] | None = None, + ) -> list[_RecordT] | list[_OtherRecordT]: """Run a query and return the results as a list of :class:`Record`. Pool performs this operation using one of its connections. Other than @@ -574,7 +631,13 @@ async def fetch( record_class=record_class ) - async def fetchval(self, query, *args, column=0, timeout=None): + async def fetchval( + self, + query: str, + *args: object, + column: int = 0, + timeout: float | None = None, + ) -> typing.Any: """Run a query and return a value in the first row. Pool performs this operation using one of its connections. Other than @@ -588,7 +651,43 @@ async def fetchval(self, query, *args, column=0, timeout=None): return await con.fetchval( query, *args, column=column, timeout=timeout) - async def fetchrow(self, query, *args, timeout=None, record_class=None): + @typing.overload + async def fetchrow( + self, + query: str, + *args: object, + timeout: float | None = ..., + record_class: None = ..., + ) -> _RecordT | None: + ... + + @typing.overload + async def fetchrow( + self, + query: str, + *args: object, + timeout: float | None = ..., + record_class: type[_OtherRecordT], + ) -> _OtherRecordT | None: + ... + + @typing.overload + async def fetchrow( + self, + query: str, + *args: object, + timeout: float | None = ..., + record_class: type[_OtherRecordT] | None, + ) -> _RecordT | _OtherRecordT | None: + ... + + async def fetchrow( + self, + query: str, + *args: object, + timeout: float | None = None, + record_class: type[_OtherRecordT] | None = None, + ) -> _RecordT | _OtherRecordT | None: """Run a query and return the first row. Pool performs this operation using one of its connections. Other than @@ -607,22 +706,22 @@ async def fetchrow(self, query, *args, timeout=None, record_class=None): async def copy_from_table( self, - table_name, + table_name: str, *, - output, - columns=None, - schema_name=None, - timeout=None, - format=None, - oids=None, - delimiter=None, - null=None, - header=None, - quote=None, - escape=None, - force_quote=None, - encoding=None - ): + output: connection._OutputType, + columns: compat.Iterable[str] | None = None, + schema_name: str | None = None, + timeout: float | None = None, + format: connection._CopyFormat | None = None, + oids: int | None = None, + delimiter: str | None = None, + null: str | None = None, + header: bool | None = None, + quote: str | None = None, + escape: str | None = None, + force_quote: bool | compat.Iterable[str] | None = None, + encoding: str | None = None, + ) -> str: """Copy table contents to a file or file-like object. Pool performs this operation using one of its connections. Other than @@ -652,20 +751,20 @@ async def copy_from_table( async def copy_from_query( self, - query, - *args, - output, - timeout=None, - format=None, - oids=None, - delimiter=None, - null=None, - header=None, - quote=None, - escape=None, - force_quote=None, - encoding=None - ): + query: str, + *args: object, + output: connection._OutputType, + timeout: float | None = None, + format: connection._CopyFormat | None = None, + oids: int | None = None, + delimiter: str | None = None, + null: str | None = None, + header: bool | None = None, + quote: str | None = None, + escape: str | None = None, + force_quote: bool | compat.Iterable[str] | None = None, + encoding: str | None = None, + ) -> str: """Copy the results of a query to a file or file-like object. Pool performs this operation using one of its connections. Other than @@ -694,26 +793,26 @@ async def copy_from_query( async def copy_to_table( self, - table_name, + table_name: str, *, - source, - columns=None, - schema_name=None, - timeout=None, - format=None, - oids=None, - freeze=None, - delimiter=None, - null=None, - header=None, - quote=None, - escape=None, - force_quote=None, - force_not_null=None, - force_null=None, - encoding=None, - where=None - ): + source: connection._SourceType, + columns: compat.Iterable[str] | None = None, + schema_name: str | None = None, + timeout: float | None = None, + format: connection._CopyFormat | None = None, + oids: int | None = None, + freeze: bool | None = None, + delimiter: str | None = None, + null: str | None = None, + header: bool | None = None, + quote: str | None = None, + escape: str | None = None, + force_quote: bool | compat.Iterable[str] | None = None, + force_not_null: bool | compat.Iterable[str] | None = None, + force_null: bool | compat.Iterable[str] | None = None, + encoding: str | None = None, + where: str | None = None, + ) -> str: """Copy data to the specified table. Pool performs this operation using one of its connections. Other than @@ -747,14 +846,16 @@ async def copy_to_table( async def copy_records_to_table( self, - table_name, + table_name: str, *, - records, - columns=None, - schema_name=None, - timeout=None, - where=None - ): + records: compat.Iterable[ + compat.Sequence[object] + ] | compat.AsyncIterable[compat.Sequence[object]], + columns: compat.Iterable[str] | None = None, + schema_name: str | None = None, + timeout: float | None = None, + where: str | None = None, + ) -> str: """Copy a list of records to the specified table using binary COPY. Pool performs this operation using one of its connections. Other than @@ -774,7 +875,9 @@ async def copy_records_to_table( where=where ) - def acquire(self, *, timeout=None): + def acquire( + self, *, timeout: float | None = None + ) -> PoolAcquireContext[_RecordT]: """Acquire a database connection from the pool. :param float timeout: A timeout for acquiring a Connection. @@ -799,11 +902,18 @@ def acquire(self, *, timeout=None): """ return PoolAcquireContext(self, timeout) - async def _acquire(self, timeout): - async def _acquire_impl(): - ch = await self._queue.get() # type: PoolConnectionHolder + async def _acquire( + self, timeout: float | None + ) -> pool_connection_proxy.PoolConnectionProxy[_RecordT]: + async def _acquire_impl() -> pool_connection_proxy.PoolConnectionProxy[ + _RecordT + ]: + if typing.TYPE_CHECKING: + assert self._queue is not None + + ch: PoolConnectionHolder[_RecordT] = await self._queue.get() try: - proxy = await ch.acquire() # type: PoolConnectionProxy + proxy = await ch.acquire() except (Exception, asyncio.CancelledError): self._queue.put_nowait(ch) raise @@ -823,7 +933,12 @@ async def _acquire_impl(): return await compat.wait_for( _acquire_impl(), timeout=timeout) - async def release(self, connection, *, timeout=None): + async def release( + self, + connection: pool_connection_proxy.PoolConnectionProxy[_RecordT], + *, + timeout: float | None = None, + ) -> None: """Release a database connection back to the pool. :param Connection connection: @@ -836,8 +951,8 @@ async def release(self, connection, *, timeout=None): .. versionchanged:: 0.14.0 Added the *timeout* parameter. """ - if (type(connection) is not PoolConnectionProxy or - connection._holder._pool is not self): + if (type(connection) is not pool_connection_proxy.PoolConnectionProxy + or connection._holder._pool is not self): raise exceptions.InterfaceError( 'Pool.release() received invalid connection: ' '{connection!r} is not a member of this pool'.format( @@ -861,7 +976,7 @@ async def release(self, connection, *, timeout=None): # pool properly. return await asyncio.shield(ch.release(timeout)) - async def close(self): + async def close(self) -> None: """Attempt to gracefully close all connections in the pool. Wait until all pool connections are released, close them and @@ -906,13 +1021,13 @@ async def close(self): self._closed = True self._closing = False - def _warn_on_long_close(self): - logger.warning('Pool.close() is taking over 60 seconds to complete. ' - 'Check if you have any unreleased connections left. ' - 'Use asyncio.wait_for() to set a timeout for ' - 'Pool.close().') + def _warn_on_long_close(self) -> None: + _logger.warning('Pool.close() is taking over 60 seconds to complete. ' + 'Check if you have any unreleased connections left. ' + 'Use asyncio.wait_for() to set a timeout for ' + 'Pool.close().') - def terminate(self): + def terminate(self) -> None: """Terminate all connections in the pool.""" if self._closed: return @@ -921,7 +1036,7 @@ def terminate(self): ch.terminate() self._closed = True - async def expire_connections(self): + async def expire_connections(self) -> None: """Expire all currently open connections. Cause all currently open connections to get replaced on the @@ -931,7 +1046,7 @@ async def expire_connections(self): """ self._generation += 1 - def _check_init(self): + def _check_init(self) -> None: if not self._initialized: if self._initializing: raise exceptions.InterfaceError( @@ -942,67 +1057,142 @@ def _check_init(self): if self._closed: raise exceptions.InterfaceError('pool is closed') - def _drop_statement_cache(self): + def _drop_statement_cache(self) -> None: # Drop statement cache for all connections in the pool. for ch in self._holders: if ch._con is not None: ch._con._drop_local_statement_cache() - def _drop_type_cache(self): + def _drop_type_cache(self) -> None: # Drop type codec cache for all connections in the pool. for ch in self._holders: if ch._con is not None: ch._con._drop_local_type_cache() - def __await__(self): + def __await__(self) -> compat.Generator[typing.Any, None, Self | None]: return self._async__init__().__await__() - async def __aenter__(self): + async def __aenter__(self) -> Self: await self._async__init__() return self - async def __aexit__(self, *exc): + async def __aexit__(self, *exc: object) -> None: await self.close() -class PoolAcquireContext: +class PoolAcquireContext(typing.Generic[_RecordT]): __slots__ = ('timeout', 'connection', 'done', 'pool') - def __init__(self, pool, timeout): + timeout: float | None + connection: pool_connection_proxy.PoolConnectionProxy[_RecordT] | None + done: bool + pool: Pool[_RecordT] + + def __init__(self, pool: Pool[_RecordT], timeout: float | None) -> None: self.pool = pool self.timeout = timeout self.connection = None self.done = False - async def __aenter__(self): + async def __aenter__( + self + ) -> pool_connection_proxy.PoolConnectionProxy[_RecordT]: if self.connection is not None or self.done: raise exceptions.InterfaceError('a connection is already acquired') self.connection = await self.pool._acquire(self.timeout) return self.connection - async def __aexit__(self, *exc): + async def __aexit__(self, *exc: object) -> None: self.done = True con = self.connection self.connection = None + if typing.TYPE_CHECKING: + assert con is not None await self.pool.release(con) - def __await__(self): + def __await__(self) -> compat.Generator[ + typing.Any, None, pool_connection_proxy.PoolConnectionProxy[_RecordT] + ]: self.done = True return self.pool._acquire(self.timeout).__await__() -def create_pool(dsn=None, *, - min_size=10, - max_size=10, - max_queries=50000, - max_inactive_connection_lifetime=300.0, - setup=None, - init=None, - loop=None, - connection_class=connection.Connection, - record_class=protocol.Record, - **connect_kwargs): +@typing.overload +def create_pool( + dsn: str | None = ..., + *, + min_size: int = ..., + max_size: int = ..., + max_queries: int = ..., + max_inactive_connection_lifetime: float = ..., + setup: _SetupCallback[_RecordT] | None = ..., + init: _InitCallback[_RecordT] | None = ..., + loop: asyncio.AbstractEventLoop | None = ..., + connection_class: type[connection.Connection[_RecordT]] = ..., + record_class: type[_RecordT], + host: connect_utils.HostType | None = ..., + port: connect_utils.PortType | None = ..., + user: str | None = ..., + password: connect_utils.PasswordType | None = ..., + passfile: str | None = ..., + database: str | None = ..., + timeout: float = ..., + statement_cache_size: int = ..., + max_cached_statement_lifetime: int = ..., + max_cacheable_statement_size: int = ..., + command_timeout: float | None = ..., + ssl: connect_utils.SSLType | None = ..., + server_settings: dict[str, str] | None = ..., +) -> Pool[_RecordT]: + ... + + +@typing.overload +def create_pool( + dsn: str | None = ..., + *, + min_size: int = ..., + max_size: int = ..., + max_queries: int = ..., + max_inactive_connection_lifetime: float = ..., + setup: _SetupCallback[protocol.Record] | None = ..., + init: _InitCallback[protocol.Record] | None = ..., + loop: asyncio.AbstractEventLoop | None = ..., + connection_class: type[connection.Connection[protocol.Record]] = ..., + host: connect_utils.HostType | None = ..., + port: connect_utils.PortType | None = ..., + user: str | None = ..., + password: connect_utils.PasswordType | None = ..., + passfile: str | None = ..., + database: str | None = ..., + timeout: float = ..., + statement_cache_size: int = ..., + max_cached_statement_lifetime: int = ..., + max_cacheable_statement_size: int = ..., + command_timeout: float | None = ..., + ssl: connect_utils.SSLType | None = ..., + server_settings: dict[str, str] | None = ..., +) -> Pool[protocol.Record]: + ... + + +def create_pool( + dsn: str | None = None, + *, + min_size: int = 10, + max_size: int = 10, + max_queries: int = 50000, + max_inactive_connection_lifetime: float = 300.0, + setup: _SetupCallback[typing.Any] | None = None, + init: _InitCallback[typing.Any] | None = None, + loop: asyncio.AbstractEventLoop | None = None, + connection_class: type[ + connection.Connection[typing.Any] + ] = connection.Connection, + record_class: type[protocol.Record] | type[_RecordT] = protocol.Record, + **connect_kwargs: typing.Any +) -> Pool[typing.Any]: r"""Create a connection pool. Can be used either with an ``async with`` block: diff --git a/asyncpg/pool_connection_proxy.py b/asyncpg/pool_connection_proxy.py new file mode 100644 index 00000000..b4d2e4fc --- /dev/null +++ b/asyncpg/pool_connection_proxy.py @@ -0,0 +1,91 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + +from __future__ import annotations + +import functools +import inspect +import typing + +from . import connection +from . import exceptions + +if typing.TYPE_CHECKING: + from . import pool + from . import protocol + + +_RecordT = typing.TypeVar('_RecordT', bound='protocol.Record') + + +class PoolConnectionProxyMeta(type): + + def __new__(mcls, name, bases, dct, *, wrap=False): + if wrap: + for attrname in dir(connection.Connection): + if attrname.startswith('_') or attrname in dct: + continue + + meth = getattr(connection.Connection, attrname) + if not inspect.isfunction(meth): + continue + + wrapper = mcls._wrap_connection_method(attrname) + wrapper = functools.update_wrapper(wrapper, meth) + dct[attrname] = wrapper + + if '__doc__' not in dct: + dct['__doc__'] = connection.Connection.__doc__ + + return super().__new__(mcls, name, bases, dct) + + @staticmethod + def _wrap_connection_method(meth_name): + def call_con_method(self, *args, **kwargs): + # This method will be owned by PoolConnectionProxy class. + if self._con is None: + raise exceptions.InterfaceError( + 'cannot call Connection.{}(): ' + 'connection has been released back to the pool'.format( + meth_name)) + + meth = getattr(self._con.__class__, meth_name) + return meth(self._con, *args, **kwargs) + + return call_con_method + + +class PoolConnectionProxy(connection._ConnectionProxy[_RecordT], + metaclass=PoolConnectionProxyMeta, + wrap=True): + + __slots__ = ('_con', '_holder') + + def __init__(self, holder: pool.PoolConnectionHolder, + con: connection.Connection[_RecordT]): + self._con = con + self._holder = holder + con._set_proxy(self) + + def __getattr__(self, attr): + # Proxy all unresolved attributes to the wrapped Connection object. + return getattr(self._con, attr) + + def _detach(self) -> connection.Connection[_RecordT]: + if self._con is None: + return + + con, self._con = self._con, None + con._set_proxy(None) + return con + + def __repr__(self): + if self._con is None: + return '<{classname} [released] {id:#x}>'.format( + classname=self.__class__.__name__, id=id(self)) + else: + return '<{classname} {con!r} {id:#x}>'.format( + classname=self.__class__.__name__, con=self._con, id=id(self)) diff --git a/asyncpg/pool_connection_proxy.pyi b/asyncpg/pool_connection_proxy.pyi new file mode 100644 index 00000000..cdb03af9 --- /dev/null +++ b/asyncpg/pool_connection_proxy.pyi @@ -0,0 +1,284 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + +import contextlib +from collections.abc import ( + AsyncIterable, + Callable, + Iterable, + Iterator, + Sequence, +) +from typing import Any, TypeVar, overload + +from . import connection +from . import cursor +from . import pool +from . import prepared_stmt +from . import protocol +from . import transaction +from . import types +from .protocol import protocol as _cprotocol + +_RecordT = TypeVar('_RecordT', bound=protocol.Record) +_OtherRecordT = TypeVar('_OtherRecordT', bound=protocol.Record) + +class PoolConnectionProxyMeta(type): ... + +class PoolConnectionProxy( + connection._ConnectionProxy[_RecordT], metaclass=PoolConnectionProxyMeta +): + __slots__ = ('_con', '_holder') + _con: connection.Connection[_RecordT] + _holder: pool.PoolConnectionHolder[_RecordT] + def __init__( + self, + holder: pool.PoolConnectionHolder[_RecordT], + con: connection.Connection[_RecordT], + ) -> None: ... + def _detach(self) -> connection.Connection[_RecordT]: ... + + # The following methods are copied from Connection + async def add_listener( + self, channel: str, callback: connection.Listener + ) -> None: ... + async def remove_listener( + self, channel: str, callback: connection.Listener + ) -> None: ... + def add_log_listener(self, callback: connection.LogListener) -> None: ... + def remove_log_listener(self, callback: connection.LogListener) -> None: ... + def add_termination_listener( + self, callback: connection.TerminationListener + ) -> None: ... + def remove_termination_listener( + self, callback: connection.TerminationListener + ) -> None: ... + def add_query_logger(self, callback: connection.QueryLogger) -> None: ... + def remove_query_logger(self, callback: connection.QueryLogger) -> None: ... + def get_server_pid(self) -> int: ... + def get_server_version(self) -> types.ServerVersion: ... + def get_settings(self) -> _cprotocol.ConnectionSettings: ... + def transaction( + self, + *, + isolation: transaction.IsolationLevels | None = ..., + readonly: bool = ..., + deferrable: bool = ..., + ) -> transaction.Transaction: ... + def is_in_transaction(self) -> bool: ... + async def execute( + self, query: str, *args: object, timeout: float | None = ... + ) -> str: ... + async def executemany( + self, + command: str, + args: Iterable[Sequence[object]], + *, + timeout: float | None = ..., + ) -> None: ... + @overload + def cursor( + self, + query: str, + *args: object, + prefetch: int | None = ..., + timeout: float | None = ..., + record_class: None = ..., + ) -> cursor.CursorFactory[_RecordT]: ... + @overload + def cursor( + self, + query: str, + *args: object, + prefetch: int | None = ..., + timeout: float | None = ..., + record_class: type[_OtherRecordT], + ) -> cursor.CursorFactory[_OtherRecordT]: ... + @overload + def cursor( + self, + query: str, + *args: object, + prefetch: int | None = ..., + timeout: float | None = ..., + record_class: type[_OtherRecordT] | None, + ) -> cursor.CursorFactory[_RecordT] | cursor.CursorFactory[_OtherRecordT]: ... + @overload + async def prepare( + self, + query: str, + *, + name: str | None = ..., + timeout: float | None = ..., + record_class: None = ..., + ) -> prepared_stmt.PreparedStatement[_RecordT]: ... + @overload + async def prepare( + self, + query: str, + *, + name: str | None = ..., + timeout: float | None = ..., + record_class: type[_OtherRecordT], + ) -> prepared_stmt.PreparedStatement[_OtherRecordT]: ... + @overload + async def prepare( + self, + query: str, + *, + name: str | None = ..., + timeout: float | None = ..., + record_class: type[_OtherRecordT] | None, + ) -> ( + prepared_stmt.PreparedStatement[_RecordT] + | prepared_stmt.PreparedStatement[_OtherRecordT] + ): ... + @overload + async def fetch( + self, + query: str, + *args: object, + timeout: float | None = ..., + record_class: None = ..., + ) -> list[_RecordT]: ... + @overload + async def fetch( + self, + query: str, + *args: object, + timeout: float | None = ..., + record_class: type[_OtherRecordT], + ) -> list[_OtherRecordT]: ... + @overload + async def fetch( + self, + query: str, + *args: object, + timeout: float | None = ..., + record_class: type[_OtherRecordT] | None, + ) -> list[_RecordT] | list[_OtherRecordT]: ... + async def fetchval( + self, + query: str, + *args: object, + column: int = ..., + timeout: float | None = ..., + ) -> Any: ... + @overload + async def fetchrow( + self, + query: str, + *args: object, + timeout: float | None = ..., + record_class: None = ..., + ) -> _RecordT | None: ... + @overload + async def fetchrow( + self, + query: str, + *args: object, + timeout: float | None = ..., + record_class: type[_OtherRecordT], + ) -> _OtherRecordT | None: ... + @overload + async def fetchrow( + self, + query: str, + *args: object, + timeout: float | None = ..., + record_class: type[_OtherRecordT] | None, + ) -> _RecordT | _OtherRecordT | None: ... + async def copy_from_table( + self, + table_name: str, + *, + output: connection._OutputType, + columns: Iterable[str] | None = ..., + schema_name: str | None = ..., + timeout: float | None = ..., + format: connection._CopyFormat | None = ..., + oids: int | None = ..., + delimiter: str | None = ..., + null: str | None = ..., + header: bool | None = ..., + quote: str | None = ..., + escape: str | None = ..., + force_quote: bool | Iterable[str] | None = ..., + encoding: str | None = ..., + ) -> str: ... + async def copy_from_query( + self, + query: str, + *args: object, + output: connection._OutputType, + timeout: float | None = ..., + format: connection._CopyFormat | None = ..., + oids: int | None = ..., + delimiter: str | None = ..., + null: str | None = ..., + header: bool | None = ..., + quote: str | None = ..., + escape: str | None = ..., + force_quote: bool | Iterable[str] | None = ..., + encoding: str | None = ..., + ) -> str: ... + async def copy_to_table( + self, + table_name: str, + *, + source: connection._SourceType, + columns: Iterable[str] | None = ..., + schema_name: str | None = ..., + timeout: float | None = ..., + format: connection._CopyFormat | None = ..., + oids: int | None = ..., + freeze: bool | None = ..., + delimiter: str | None = ..., + null: str | None = ..., + header: bool | None = ..., + quote: str | None = ..., + escape: str | None = ..., + force_quote: bool | Iterable[str] | None = ..., + force_not_null: bool | Iterable[str] | None = ..., + force_null: bool | Iterable[str] | None = ..., + encoding: str | None = ..., + where: str | None = ..., + ) -> str: ... + async def copy_records_to_table( + self, + table_name: str, + *, + records: Iterable[Sequence[object]] | AsyncIterable[Sequence[object]], + columns: Iterable[str] | None = ..., + schema_name: str | None = ..., + timeout: float | None = ..., + where: str | None = ..., + ) -> str: ... + async def set_type_codec( + self, + typename: str, + *, + schema: str = ..., + encoder: Callable[[Any], Any], + decoder: Callable[[Any], Any], + format: str = ..., + ) -> None: ... + async def reset_type_codec(self, typename: str, *, schema: str = ...) -> None: ... + async def set_builtin_type_codec( + self, + typename: str, + *, + schema: str = ..., + codec_name: str, + format: str | None = ..., + ) -> None: ... + def is_closed(self) -> bool: ... + async def close(self, *, timeout: float | None = ...) -> None: ... + def terminate(self) -> None: ... + async def reset(self, *, timeout: float | None = ...) -> None: ... + async def reload_schema_state(self) -> None: ... + @contextlib.contextmanager + def query_logger(self, callback: connection.QueryLogger) -> Iterator[None]: ... diff --git a/asyncpg/prepared_stmt.py b/asyncpg/prepared_stmt.py index 8e241d67..f49163a2 100644 --- a/asyncpg/prepared_stmt.py +++ b/asyncpg/prepared_stmt.py @@ -4,20 +4,52 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +from __future__ import annotations import json +import typing from . import connresource from . import cursor from . import exceptions +if typing.TYPE_CHECKING: + from .protocol import protocol as _cprotocol + from . import compat + from . import connection as _connection + from . import types -class PreparedStatement(connresource.ConnectionResource): + +_RecordT = typing.TypeVar('_RecordT', bound='_cprotocol.Record') +_T = typing.TypeVar('_T') +_T_co = typing.TypeVar('_T_co', covariant=True) + + +class _Executor(typing.Protocol[_T_co]): + def __call__( + self, __protocol: _cprotocol.BaseProtocol[typing.Any] + ) -> compat.Awaitable[_T_co]: + ... + + +class PreparedStatement( + connresource.ConnectionResource, + typing.Generic[_RecordT] +): """A representation of a prepared statement.""" __slots__ = ('_state', '_query', '_last_status') - def __init__(self, connection, query, state): + _state: _cprotocol.PreparedStatementState[_RecordT] + _query: str + _last_status: bytes | None + + def __init__( + self, + connection: _connection.Connection[typing.Any], + query: str, + state: _cprotocol.PreparedStatementState[_RecordT], + ) -> None: super().__init__(connection) self._state = state self._query = query @@ -44,7 +76,7 @@ def get_query(self) -> str: return self._query @connresource.guarded - def get_statusmsg(self) -> str: + def get_statusmsg(self) -> str | None: """Return the status of the executed command. Example:: @@ -58,7 +90,7 @@ def get_statusmsg(self) -> str: return self._last_status.decode() @connresource.guarded - def get_parameters(self): + def get_parameters(self) -> tuple[types.Type, ...]: """Return a description of statement parameters types. :return: A tuple of :class:`asyncpg.types.Type`. @@ -75,7 +107,7 @@ def get_parameters(self): return self._state._get_parameters() @connresource.guarded - def get_attributes(self): + def get_attributes(self) -> tuple[types.Attribute, ...]: """Return a description of relation attributes (columns). :return: A tuple of :class:`asyncpg.types.Attribute`. @@ -100,8 +132,8 @@ def get_attributes(self): return self._state._get_attributes() @connresource.guarded - def cursor(self, *args, prefetch=None, - timeout=None) -> cursor.CursorFactory: + def cursor(self, *args: object, prefetch: int | None = None, + timeout: float | None = None) -> cursor.CursorFactory[_RecordT]: """Return a *cursor factory* for the prepared statement. :param args: Query arguments. @@ -122,7 +154,9 @@ def cursor(self, *args, prefetch=None, ) @connresource.guarded - async def explain(self, *args, analyze=False): + async def explain( + self, *args: object, analyze: bool = False + ) -> typing.Any: """Return the execution plan of the statement. :param args: Query arguments. @@ -164,7 +198,9 @@ async def explain(self, *args, analyze=False): return json.loads(data) @connresource.guarded - async def fetch(self, *args, timeout=None): + async def fetch( + self, *args: object, timeout: float | None = None + ) -> list[_RecordT]: r"""Execute the statement and return a list of :class:`Record` objects. :param str query: Query text @@ -177,7 +213,9 @@ async def fetch(self, *args, timeout=None): return data @connresource.guarded - async def fetchval(self, *args, column=0, timeout=None): + async def fetchval( + self, *args: object, column: int = 0, timeout: float | None = None + ) -> typing.Any: """Execute the statement and return a value in the first row. :param args: Query arguments. @@ -196,7 +234,9 @@ async def fetchval(self, *args, column=0, timeout=None): return data[0][column] @connresource.guarded - async def fetchrow(self, *args, timeout=None): + async def fetchrow( + self, *args: object, timeout: float | None = None + ) -> _RecordT | None: """Execute the statement and return the first row. :param str query: Query text @@ -211,7 +251,12 @@ async def fetchrow(self, *args, timeout=None): return data[0] @connresource.guarded - async def executemany(self, args, *, timeout: float=None): + async def executemany( + self, + args: compat.Iterable[compat.Sequence[object]], + *, + timeout: float | None = None + ) -> None: """Execute the statement for each sequence of arguments in *args*. :param args: An iterable containing sequences of arguments. @@ -224,7 +269,7 @@ async def executemany(self, args, *, timeout: float=None): lambda protocol: protocol.bind_execute_many( self._state, args, '', timeout)) - async def __do_execute(self, executor): + async def __do_execute(self, executor: _Executor[_T]) -> _T: protocol = self._connection._protocol try: return await executor(protocol) @@ -237,23 +282,28 @@ async def __do_execute(self, executor): self._state.mark_closed() raise - async def __bind_execute(self, args, limit, timeout): - data, status, _ = await self.__do_execute( - lambda protocol: protocol.bind_execute( - self._state, args, '', limit, True, timeout)) + async def __bind_execute( + self, args: compat.Sequence[object], limit: int, timeout: float | None + ) -> list[_RecordT]: + executor: _Executor[ + tuple[list[_RecordT], bytes, bool] + ] = lambda protocol: protocol.bind_execute( + self._state, args, '', limit, True, timeout + ) + data, status, _ = await self.__do_execute(executor) self._last_status = status return data - def _check_open(self, meth_name): + def _check_open(self, meth_name: str) -> None: if self._state.closed: raise exceptions.InterfaceError( 'cannot call PreparedStmt.{}(): ' 'the prepared statement is closed'.format(meth_name)) - def _check_conn_validity(self, meth_name): + def _check_conn_validity(self, meth_name: str) -> None: self._check_open(meth_name) super()._check_conn_validity(meth_name) - def __del__(self): + def __del__(self) -> None: self._state.detach() self._connection._maybe_gc_stmt(self._state) diff --git a/asyncpg/protocol/__init__.py b/asyncpg/protocol/__init__.py index 8b3e06a0..af9287bd 100644 --- a/asyncpg/protocol/__init__.py +++ b/asyncpg/protocol/__init__.py @@ -6,4 +6,6 @@ # flake8: NOQA +from __future__ import annotations + from .protocol import Protocol, Record, NO_TIMEOUT, BUILTIN_TYPE_NAME_MAP diff --git a/asyncpg/protocol/protocol.pyi b/asyncpg/protocol/protocol.pyi new file mode 100644 index 00000000..ea468e6d --- /dev/null +++ b/asyncpg/protocol/protocol.pyi @@ -0,0 +1,300 @@ +import asyncio +import asyncio.protocols +import hmac +from codecs import CodecInfo +from collections.abc import Callable, Iterable, Iterator, Sequence +from hashlib import md5, sha256 +from typing import ( + Any, + ClassVar, + Final, + Generic, + Literal, + NewType, + TypeVar, + final, + overload, +) +from typing_extensions import TypeAlias + +import asyncpg.pgproto.pgproto + +from ..connect_utils import _ConnectionParameters +from ..pgproto.pgproto import WriteBuffer +from ..types import Attribute, Type + +_T = TypeVar('_T') +_Record = TypeVar('_Record', bound=Record) +_OtherRecord = TypeVar('_OtherRecord', bound=Record) +_PreparedStatementState = TypeVar( + '_PreparedStatementState', bound=PreparedStatementState[Any] +) + +_NoTimeoutType = NewType('_NoTimeoutType', object) +_TimeoutType: TypeAlias = float | None | _NoTimeoutType + +BUILTIN_TYPE_NAME_MAP: Final[dict[str, int]] +BUILTIN_TYPE_OID_MAP: Final[dict[int, str]] +NO_TIMEOUT: Final[_NoTimeoutType] + +hashlib_md5 = md5 + +@final +class ConnectionSettings(asyncpg.pgproto.pgproto.CodecContext): + __pyx_vtable__: Any + def __init__(self, conn_key: object) -> None: ... + def add_python_codec( + self, + typeoid: int, + typename: str, + typeschema: str, + typeinfos: Iterable[object], + typekind: str, + encoder: Callable[[Any], Any], + decoder: Callable[[Any], Any], + format: object, + ) -> Any: ... + def clear_type_cache(self) -> None: ... + def get_data_codec( + self, oid: int, format: object = ..., ignore_custom_codec: bool = ... + ) -> Any: ... + def get_text_codec(self) -> CodecInfo: ... + def register_data_types(self, types: Iterable[object]) -> None: ... + def remove_python_codec( + self, typeoid: int, typename: str, typeschema: str + ) -> None: ... + def set_builtin_type_codec( + self, + typeoid: int, + typename: str, + typeschema: str, + typekind: str, + alias_to: str, + format: object = ..., + ) -> Any: ... + def __getattr__(self, name: str) -> Any: ... + def __reduce__(self) -> Any: ... + +@final +class PreparedStatementState(Generic[_Record]): + closed: bool + prepared: bool + name: str + query: str + refs: int + record_class: type[_Record] + ignore_custom_codec: bool + __pyx_vtable__: Any + def __init__( + self, + name: str, + query: str, + protocol: BaseProtocol[Any], + record_class: type[_Record], + ignore_custom_codec: bool, + ) -> None: ... + def _get_parameters(self) -> tuple[Type, ...]: ... + def _get_attributes(self) -> tuple[Attribute, ...]: ... + def _init_types(self) -> set[int]: ... + def _init_codecs(self) -> None: ... + def attach(self) -> None: ... + def detach(self) -> None: ... + def mark_closed(self) -> None: ... + def mark_unprepared(self) -> None: ... + def __reduce__(self) -> Any: ... + +class CoreProtocol: + backend_pid: Any + backend_secret: Any + __pyx_vtable__: Any + def __init__(self, con_params: _ConnectionParameters) -> None: ... + def is_in_transaction(self) -> bool: ... + def __reduce__(self) -> Any: ... + +class BaseProtocol(CoreProtocol, Generic[_Record]): + queries_count: Any + is_ssl: bool + __pyx_vtable__: Any + def __init__( + self, + addr: object, + connected_fut: object, + con_params: _ConnectionParameters, + record_class: type[_Record], + loop: object, + ) -> None: ... + def set_connection(self, connection: object) -> None: ... + def get_server_pid(self, *args: object, **kwargs: object) -> int: ... + def get_settings(self, *args: object, **kwargs: object) -> ConnectionSettings: ... + def get_record_class(self) -> type[_Record]: ... + def abort(self) -> None: ... + async def bind( + self, + state: PreparedStatementState[_OtherRecord], + args: Sequence[object], + portal_name: str, + timeout: _TimeoutType, + ) -> Any: ... + @overload + async def bind_execute( + self, + state: PreparedStatementState[_OtherRecord], + args: Sequence[object], + portal_name: str, + limit: int, + return_extra: Literal[False], + timeout: _TimeoutType, + ) -> list[_OtherRecord]: ... + @overload + async def bind_execute( + self, + state: PreparedStatementState[_OtherRecord], + args: Sequence[object], + portal_name: str, + limit: int, + return_extra: Literal[True], + timeout: _TimeoutType, + ) -> tuple[list[_OtherRecord], bytes, bool]: ... + @overload + async def bind_execute( + self, + state: PreparedStatementState[_OtherRecord], + args: Sequence[object], + portal_name: str, + limit: int, + return_extra: bool, + timeout: _TimeoutType, + ) -> list[_OtherRecord] | tuple[list[_OtherRecord], bytes, bool]: ... + async def bind_execute_many( + self, + state: PreparedStatementState[_OtherRecord], + args: Iterable[Sequence[object]], + portal_name: str, + timeout: _TimeoutType, + ) -> None: ... + async def close(self, timeout: _TimeoutType) -> None: ... + def _get_timeout(self, timeout: _TimeoutType) -> float | None: ... + def _is_cancelling(self) -> bool: ... + async def _wait_for_cancellation(self) -> None: ... + async def close_statement( + self, state: PreparedStatementState[_OtherRecord], timeout: _TimeoutType + ) -> Any: ... + async def copy_in(self, *args: object, **kwargs: object) -> str: ... + async def copy_out(self, *args: object, **kwargs: object) -> str: ... + async def execute(self, *args: object, **kwargs: object) -> Any: ... + def is_closed(self, *args: object, **kwargs: object) -> Any: ... + def is_connected(self, *args: object, **kwargs: object) -> Any: ... + def data_received(self, data: object) -> None: ... + def connection_made(self, transport: object) -> None: ... + def connection_lost(self, exc: Exception | None) -> None: ... + def pause_writing(self, *args: object, **kwargs: object) -> Any: ... + @overload + async def prepare( + self, + stmt_name: str, + query: str, + timeout: float | None = ..., + *, + state: _PreparedStatementState, + ignore_custom_codec: bool = ..., + record_class: None, + ) -> _PreparedStatementState: ... + @overload + async def prepare( + self, + stmt_name: str, + query: str, + timeout: float | None = ..., + *, + state: None = ..., + ignore_custom_codec: bool = ..., + record_class: type[_OtherRecord], + ) -> PreparedStatementState[_OtherRecord]: ... + async def close_portal(self, portal_name: str, timeout: _TimeoutType) -> None: ... + async def query(self, *args: object, **kwargs: object) -> str: ... + def resume_writing(self, *args: object, **kwargs: object) -> Any: ... + def __reduce__(self) -> Any: ... + +@final +class Codec: + __pyx_vtable__: Any + def __reduce__(self) -> Any: ... + +class DataCodecConfig: + __pyx_vtable__: Any + def __init__(self, cache_key: object) -> None: ... + def add_python_codec( + self, + typeoid: int, + typename: str, + typeschema: str, + typekind: str, + typeinfos: Iterable[object], + encoder: Callable[[ConnectionSettings, WriteBuffer, object], object], + decoder: Callable[..., object], + format: object, + xformat: object, + ) -> Any: ... + def add_types(self, types: Iterable[object]) -> Any: ... + def clear_type_cache(self) -> None: ... + def declare_fallback_codec(self, oid: int, name: str, schema: str) -> Codec: ... + def remove_python_codec( + self, typeoid: int, typename: str, typeschema: str + ) -> Any: ... + def set_builtin_type_codec( + self, + typeoid: int, + typename: str, + typeschema: str, + typekind: str, + alias_to: str, + format: object = ..., + ) -> Any: ... + def __reduce__(self) -> Any: ... + +class Protocol(BaseProtocol[_Record], asyncio.protocols.Protocol): ... + +class Record: + @overload + def get(self, key: str) -> Any | None: ... + @overload + def get(self, key: str, default: _T) -> Any | _T: ... + def items(self) -> Iterator[tuple[str, Any]]: ... + def keys(self) -> Iterator[str]: ... + def values(self) -> Iterator[Any]: ... + @overload + def __getitem__(self, index: str) -> Any: ... + @overload + def __getitem__(self, index: int) -> Any: ... + @overload + def __getitem__(self, index: slice) -> tuple[Any, ...]: ... + def __iter__(self) -> Iterator[Any]: ... + def __contains__(self, x: object) -> bool: ... + def __len__(self) -> int: ... + +class Timer: + def __init__(self, budget: float | None) -> None: ... + def __enter__(self) -> None: ... + def __exit__(self, et: object, e: object, tb: object) -> None: ... + def get_remaining_budget(self) -> float: ... + def has_budget_greater_than(self, amount: float) -> bool: ... + +@final +class SCRAMAuthentication: + AUTHENTICATION_METHODS: ClassVar[list[str]] + DEFAULT_CLIENT_NONCE_BYTES: ClassVar[int] + DIGEST = sha256 + REQUIREMENTS_CLIENT_FINAL_MESSAGE: ClassVar[list[str]] + REQUIREMENTS_CLIENT_PROOF: ClassVar[list[str]] + SASLPREP_PROHIBITED: ClassVar[tuple[Callable[[str], bool], ...]] + authentication_method: bytes + authorization_message: bytes | None + client_channel_binding: bytes + client_first_message_bare: bytes | None + client_nonce: bytes | None + client_proof: bytes | None + password_salt: bytes | None + password_iterations: int + server_first_message: bytes | None + server_key: hmac.HMAC | None + server_nonce: bytes | None diff --git a/asyncpg/py.typed b/asyncpg/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/asyncpg/serverversion.py b/asyncpg/serverversion.py index 31568a2e..80fca72a 100644 --- a/asyncpg/serverversion.py +++ b/asyncpg/serverversion.py @@ -4,12 +4,14 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +from __future__ import annotations import re +import typing from .types import ServerVersion -version_regex = re.compile( +version_regex: typing.Final = re.compile( r"(Postgre[^\s]*)?\s*" r"(?P[0-9]+)\.?" r"((?P[0-9]+)\.?)?" @@ -19,7 +21,15 @@ ) -def split_server_version_string(version_string): +class _VersionDict(typing.TypedDict): + major: int + minor: int | None + micro: int | None + releaselevel: str | None + serial: int | None + + +def split_server_version_string(version_string: str) -> ServerVersion: version_match = version_regex.search(version_string) if version_match is None: @@ -28,17 +38,17 @@ def split_server_version_string(version_string): f'version from "{version_string}"' ) - version = version_match.groupdict() + version = typing.cast(_VersionDict, version_match.groupdict()) for ver_key, ver_value in version.items(): # Cast all possible versions parts to int try: - version[ver_key] = int(ver_value) + version[ver_key] = int(ver_value) # type: ignore[literal-required, call-overload] # noqa: E501 except (TypeError, ValueError): pass - if version.get("major") < 10: + if version["major"] < 10: return ServerVersion( - version.get("major"), + version["major"], version.get("minor") or 0, version.get("micro") or 0, version.get("releaselevel") or "final", @@ -52,7 +62,7 @@ def split_server_version_string(version_string): # want to keep that behaviour consistent, i.e not fail # a major version check due to a bugfix release. return ServerVersion( - version.get("major"), + version["major"], 0, version.get("minor") or 0, version.get("releaselevel") or "final", diff --git a/asyncpg/transaction.py b/asyncpg/transaction.py index 562811e6..59a6fe7f 100644 --- a/asyncpg/transaction.py +++ b/asyncpg/transaction.py @@ -4,12 +4,17 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +from __future__ import annotations import enum +import typing from . import connresource from . import exceptions as apg_errors +if typing.TYPE_CHECKING: + from . import connection as _connection + class TransactionState(enum.Enum): NEW = 0 @@ -19,13 +24,16 @@ class TransactionState(enum.Enum): FAILED = 4 -ISOLATION_LEVELS = { +IsolationLevels = typing.Literal[ + 'read_committed', 'read_uncommitted', 'serializable', 'repeatable_read' +] +ISOLATION_LEVELS: typing.Final[set[IsolationLevels]] = { 'read_committed', 'read_uncommitted', 'serializable', 'repeatable_read', } -ISOLATION_LEVELS_BY_VALUE = { +ISOLATION_LEVELS_BY_VALUE: typing.Final[dict[str, IsolationLevels]] = { 'read committed': 'read_committed', 'read uncommitted': 'read_uncommitted', 'serializable': 'serializable', @@ -41,10 +49,24 @@ class Transaction(connresource.ConnectionResource): function. """ - __slots__ = ('_connection', '_isolation', '_readonly', '_deferrable', + __slots__ = ('_isolation', '_readonly', '_deferrable', '_state', '_nested', '_id', '_managed') - def __init__(self, connection, isolation, readonly, deferrable): + _isolation: IsolationLevels | None + _readonly: bool + _deferrable: bool + _state: TransactionState + _nested: bool + _id: str | None + _managed: bool + + def __init__( + self, + connection: _connection.Connection[typing.Any], + isolation: IsolationLevels | None, + readonly: bool, + deferrable: bool, + ) -> None: super().__init__(connection) if isolation and isolation not in ISOLATION_LEVELS: @@ -60,14 +82,14 @@ def __init__(self, connection, isolation, readonly, deferrable): self._id = None self._managed = False - async def __aenter__(self): + async def __aenter__(self) -> None: if self._managed: raise apg_errors.InterfaceError( 'cannot enter context: already in an `async with` block') self._managed = True await self.start() - async def __aexit__(self, extype, ex, tb): + async def __aexit__(self, extype: object, ex: object, tb: object) -> None: try: self._check_conn_validity('__aexit__') except apg_errors.InterfaceError: @@ -93,7 +115,7 @@ async def __aexit__(self, extype, ex, tb): self._managed = False @connresource.guarded - async def start(self): + async def start(self) -> None: """Enter the transaction or savepoint block.""" self.__check_state_base('start') if self._state is TransactionState.STARTED: @@ -150,7 +172,7 @@ async def start(self): else: self._state = TransactionState.STARTED - def __check_state_base(self, opname): + def __check_state_base(self, opname: str) -> None: if self._state is TransactionState.COMMITTED: raise apg_errors.InterfaceError( 'cannot {}; the transaction is already committed'.format( @@ -164,7 +186,7 @@ def __check_state_base(self, opname): 'cannot {}; the transaction is in error state'.format( opname)) - def __check_state(self, opname): + def __check_state(self, opname: str) -> None: if self._state is not TransactionState.STARTED: if self._state is TransactionState.NEW: raise apg_errors.InterfaceError( @@ -172,7 +194,7 @@ def __check_state(self, opname): opname)) self.__check_state_base(opname) - async def __commit(self): + async def __commit(self) -> None: self.__check_state('commit') if self._connection._top_xact is self: @@ -191,7 +213,7 @@ async def __commit(self): else: self._state = TransactionState.COMMITTED - async def __rollback(self): + async def __rollback(self) -> None: self.__check_state('rollback') if self._connection._top_xact is self: @@ -211,7 +233,7 @@ async def __rollback(self): self._state = TransactionState.ROLLEDBACK @connresource.guarded - async def commit(self): + async def commit(self) -> None: """Exit the transaction or savepoint block and commit changes.""" if self._managed: raise apg_errors.InterfaceError( @@ -219,15 +241,15 @@ async def commit(self): await self.__commit() @connresource.guarded - async def rollback(self): + async def rollback(self) -> None: """Exit the transaction or savepoint block and rollback changes.""" if self._managed: raise apg_errors.InterfaceError( 'cannot manually rollback from within an `async with` block') await self.__rollback() - def __repr__(self): - attrs = [] + def __repr__(self) -> str: + attrs: list[str] = [] attrs.append('state:{}'.format(self._state.name.lower())) if self._isolation is not None: diff --git a/asyncpg/types.py b/asyncpg/types.py index bd5813fc..11055509 100644 --- a/asyncpg/types.py +++ b/asyncpg/types.py @@ -4,8 +4,17 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +from __future__ import annotations -import collections +import typing + +if typing.TYPE_CHECKING: + import sys + + if sys.version_info < (3, 11): + from typing_extensions import Self + else: + from typing import Self from asyncpg.pgproto.types import ( BitString, Point, Path, Polygon, @@ -19,7 +28,13 @@ ) -Type = collections.namedtuple('Type', ['oid', 'name', 'kind', 'schema']) +class Type(typing.NamedTuple): + oid: int + name: str + kind: str + schema: str + + Type.__doc__ = 'Database data type.' Type.oid.__doc__ = 'OID of the type.' Type.name.__doc__ = 'Type name. For example "int2".' @@ -28,25 +43,61 @@ Type.schema.__doc__ = 'Name of the database schema that defines the type.' -Attribute = collections.namedtuple('Attribute', ['name', 'type']) +class Attribute(typing.NamedTuple): + name: str + type: Type + + Attribute.__doc__ = 'Database relation attribute.' Attribute.name.__doc__ = 'Attribute name.' Attribute.type.__doc__ = 'Attribute data type :class:`asyncpg.types.Type`.' -ServerVersion = collections.namedtuple( - 'ServerVersion', ['major', 'minor', 'micro', 'releaselevel', 'serial']) +class ServerVersion(typing.NamedTuple): + major: int + minor: int + micro: int + releaselevel: str + serial: int + + ServerVersion.__doc__ = 'PostgreSQL server version tuple.' -class Range: - """Immutable representation of PostgreSQL `range` type.""" +class _RangeValue(typing.Protocol): + def __eq__(self, __value: object) -> bool: + ... + + def __lt__(self, __other: _RangeValue) -> bool: + ... + + def __gt__(self, __other: _RangeValue) -> bool: + ... + - __slots__ = '_lower', '_upper', '_lower_inc', '_upper_inc', '_empty' +_RV = typing.TypeVar('_RV', bound=_RangeValue) + + +class Range(typing.Generic[_RV]): + """Immutable representation of PostgreSQL `range` type.""" - def __init__(self, lower=None, upper=None, *, - lower_inc=True, upper_inc=False, - empty=False): + __slots__ = ('_lower', '_upper', '_lower_inc', '_upper_inc', '_empty') + + _lower: _RV | None + _upper: _RV | None + _lower_inc: bool + _upper_inc: bool + _empty: bool + + def __init__( + self, + lower: _RV | None = None, + upper: _RV | None = None, + *, + lower_inc: bool = True, + upper_inc: bool = False, + empty: bool = False + ) -> None: self._empty = empty if empty: self._lower = self._upper = None @@ -58,34 +109,34 @@ def __init__(self, lower=None, upper=None, *, self._upper_inc = upper is not None and upper_inc @property - def lower(self): + def lower(self) -> _RV | None: return self._lower @property - def lower_inc(self): + def lower_inc(self) -> bool: return self._lower_inc @property - def lower_inf(self): + def lower_inf(self) -> bool: return self._lower is None and not self._empty @property - def upper(self): + def upper(self) -> _RV | None: return self._upper @property - def upper_inc(self): + def upper_inc(self) -> bool: return self._upper_inc @property - def upper_inf(self): + def upper_inf(self) -> bool: return self._upper is None and not self._empty @property - def isempty(self): + def isempty(self) -> bool: return self._empty - def _issubset_lower(self, other): + def _issubset_lower(self, other: Self) -> bool: if other._lower is None: return True if self._lower is None: @@ -96,7 +147,7 @@ def _issubset_lower(self, other): and (other._lower_inc or not self._lower_inc) ) - def _issubset_upper(self, other): + def _issubset_upper(self, other: Self) -> bool: if other._upper is None: return True if self._upper is None: @@ -107,7 +158,7 @@ def _issubset_upper(self, other): and (other._upper_inc or not self._upper_inc) ) - def issubset(self, other): + def issubset(self, other: Self) -> bool: if self._empty: return True if other._empty: @@ -115,13 +166,13 @@ def issubset(self, other): return self._issubset_lower(other) and self._issubset_upper(other) - def issuperset(self, other): + def issuperset(self, other: Self) -> bool: return other.issubset(self) - def __bool__(self): + def __bool__(self) -> bool: return not self._empty - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if not isinstance(other, Range): return NotImplemented @@ -132,14 +183,14 @@ def __eq__(self, other): self._upper_inc, self._empty ) == ( - other._lower, - other._upper, + other._lower, # pyright: ignore [reportUnknownMemberType] + other._upper, # pyright: ignore [reportUnknownMemberType] other._lower_inc, other._upper_inc, other._empty ) - def __hash__(self): + def __hash__(self) -> int: return hash(( self._lower, self._upper, @@ -148,7 +199,7 @@ def __hash__(self): self._empty )) - def __repr__(self): + def __repr__(self) -> str: if self._empty: desc = 'empty' else: diff --git a/asyncpg/utils.py b/asyncpg/utils.py index 3940e04d..941ee585 100644 --- a/asyncpg/utils.py +++ b/asyncpg/utils.py @@ -4,24 +4,33 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +from __future__ import annotations import re +import typing +if typing.TYPE_CHECKING: + from . import connection -def _quote_ident(ident): + +def _quote_ident(ident: str) -> str: return '"{}"'.format(ident.replace('"', '""')) -def _quote_literal(string): +def _quote_literal(string: str) -> str: return "'{}'".format(string.replace("'", "''")) -async def _mogrify(conn, query, args): +async def _mogrify( + conn: connection.Connection[typing.Any], + query: str, + args: tuple[typing.Any, ...] +) -> str: """Safely inline arguments to query text.""" # Introspect the target query for argument types and # build a list of safely-quoted fully-qualified type names. ps = await conn.prepare(query) - paramtypes = [] + paramtypes: list[str] = [] for t in ps.get_parameters(): if t.name.endswith('[]'): pname = '_' + t.name[:-2] @@ -40,6 +49,9 @@ async def _mogrify(conn, query, args): textified = await conn.fetchrow( 'SELECT {cols}'.format(cols=', '.join(cols)), *args) + if typing.TYPE_CHECKING: + assert textified is not None + # Finally, replace $n references with text values. return re.sub( r'\$(\d+)\b', lambda m: textified[int(m.group(1)) - 1], query) diff --git a/pyproject.toml b/pyproject.toml index ed2340a7..7c852418 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ classifiers = [ "Topic :: Database :: Front-Ends", ] dependencies = [ - 'async_timeout>=4.0.3; python_version < "3.12.0"' + 'async_timeout>=4.0.3; python_version < "3.12.0"', ] [project.urls] @@ -37,7 +37,9 @@ github = "https://github.com/MagicStack/asyncpg" [project.optional-dependencies] test = [ 'flake8~=6.1', + 'flake8-pyi~=24.1.0', 'uvloop>=0.15.3; platform_system != "Windows" and python_version < "3.12.0"', + 'mypy~=1.8.0' ] docs = [ 'Sphinx~=5.3.0', @@ -102,3 +104,15 @@ exclude_lines = [ "if __name__ == .__main__.", ] show_missing = true + +[tool.mypy] +incremental = true +strict = true +implicit_reexport = true + +[[tool.mypy.overrides]] +module = [ + "asyncpg._testbase", + "asyncpg._testbase.*" +] +ignore_errors = true diff --git a/setup.py b/setup.py index c4d42d82..f7c3c471 100644 --- a/setup.py +++ b/setup.py @@ -43,7 +43,7 @@ with open(str(_ROOT / 'asyncpg' / '_version.py')) as f: for line in f: - if line.startswith('__version__ ='): + if line.startswith('__version__: typing.Final ='): _, _, version = line.partition('=') VERSION = version.strip(" \n'\"") break diff --git a/tests/test__sourcecode.py b/tests/test__sourcecode.py index 28ffdea7..b19044d4 100644 --- a/tests/test__sourcecode.py +++ b/tests/test__sourcecode.py @@ -14,7 +14,7 @@ def find_root(): return os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -class TestFlake8(unittest.TestCase): +class TestCodeQuality(unittest.TestCase): def test_flake8(self): try: @@ -38,3 +38,34 @@ def test_flake8(self): output = ex.output.decode() raise AssertionError( 'flake8 validation failed:\n{}'.format(output)) from None + + def test_mypy(self): + try: + import mypy # NoQA + except ImportError: + raise unittest.SkipTest('mypy module is missing') + + root_path = find_root() + config_path = os.path.join(root_path, 'pyproject.toml') + if not os.path.exists(config_path): + raise RuntimeError('could not locate mypy.ini file') + + try: + subprocess.run( + [ + sys.executable, + '-m', + 'mypy', + '--config-file', + config_path, + 'asyncpg' + ], + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + cwd=root_path + ) + except subprocess.CalledProcessError as ex: + output = ex.output.decode() + raise AssertionError( + 'mypy validation failed:\n{}'.format(output)) from None diff --git a/tools/generate_exceptions.py b/tools/generate_exceptions.py index 0b626558..bea0d30e 100755 --- a/tools/generate_exceptions.py +++ b/tools/generate_exceptions.py @@ -13,7 +13,8 @@ import string import textwrap -from asyncpg.exceptions import _base as apg_exc +from asyncpg.exceptions import _postgres_message as _pgm_exc +from asyncpg.exceptions import _base as _apg_exc _namemap = { @@ -87,14 +88,15 @@ class {clsname}({base}): buf = '# GENERATED FROM postgresql/src/backend/utils/errcodes.txt\n' + \ '# DO NOT MODIFY, use tools/generate_exceptions.py to update\n\n' + \ - 'from ._base import * # NOQA\nfrom . import _base\n\n\n' + 'from __future__ import annotations\n\n' + \ + 'import typing\nfrom ._base import * # NOQA\nfrom . import _base\n\n\n' classes = [] clsnames = set() def _add_class(clsname, base, sqlstate, docstring): if sqlstate: - sqlstate = "sqlstate = '{}'".format(sqlstate) + sqlstate = "sqlstate: typing.ClassVar[str] = '{}'".format(sqlstate) else: sqlstate = '' @@ -150,10 +152,10 @@ def _add_class(clsname, base, sqlstate, docstring): else: base = section_class - existing = apg_exc.PostgresMessageMeta.get_message_class_for_sqlstate( + existing = _pgm_exc.PostgresMessageMeta.get_message_class_for_sqlstate( sqlstate) - if (existing and existing is not apg_exc.UnknownPostgresError and + if (existing and existing is not _apg_exc.UnknownPostgresError and existing.__doc__): docstring = '"""{}"""\n\n '.format(existing.__doc__) else: @@ -164,7 +166,7 @@ def _add_class(clsname, base, sqlstate, docstring): subclasses = _subclassmap.get(sqlstate, []) for subclass in subclasses: - existing = getattr(apg_exc, subclass, None) + existing = getattr(_apg_exc, subclass, None) if existing and existing.__doc__: docstring = '"""{}"""\n\n '.format(existing.__doc__) else: @@ -176,7 +178,7 @@ def _add_class(clsname, base, sqlstate, docstring): buf += '\n\n\n'.join(classes) _all = textwrap.wrap(', '.join('{!r}'.format(c) for c in sorted(clsnames))) - buf += '\n\n\n__all__ = (\n {}\n)'.format( + buf += '\n\n\n__all__ = [\n {}\n]'.format( '\n '.join(_all)) buf += '\n\n__all__ += _base.__all__'