From 65ab4dad14fd4aa57c494c209556fee525c3d71d Mon Sep 17 00:00:00 2001 From: Yury Pliner Date: Sat, 5 Dec 2020 14:57:10 +0500 Subject: [PATCH] Fixes Engine.release method to release connection in any way --- aiopg/sa/engine.py | 5 --- tests/conftest.py | 73 +++++++++++++++++++++++++++++++++++++++ tests/test_async_await.py | 2 +- tests/test_sa_engine.py | 28 +++++++++++++-- 4 files changed, 99 insertions(+), 9 deletions(-) diff --git a/aiopg/sa/engine.py b/aiopg/sa/engine.py index 30c7cf48..1e772686 100644 --- a/aiopg/sa/engine.py +++ b/aiopg/sa/engine.py @@ -5,7 +5,6 @@ from ..connection import TIMEOUT from ..utils import _PoolAcquireContextManager, _PoolContextManager from .connection import SAConnection -from .exc import InvalidRequestError try: from sqlalchemy.dialects.postgresql.psycopg2 import ( @@ -169,10 +168,6 @@ async def _acquire(self): return conn def release(self, conn): - """Revert back connection to pool.""" - if conn.in_transaction: - raise InvalidRequestError("Cannot release a connection with " - "not finished transaction") raw = conn.connection fut = self._pool.release(raw) return fut diff --git a/tests/conftest.py b/tests/conftest.py index c83bd004..f3555a36 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -391,3 +391,76 @@ def warning(): @pytest.fixture def log(): yield _AssertLogsContext + + +@pytest.fixture +def tcp_proxy(loop): + proxy = None + + async def go(src_port, dst_port): + nonlocal proxy + proxy = TcpProxy( + dst_port=dst_port, + src_port=src_port, + ) + await proxy.start() + return proxy + yield go + if proxy is not None: + loop.run_until_complete(proxy.disconnect()) + + +class TcpProxy: + """ + TCP proxy. Allows simulating connection breaks in tests. + """ + MAX_BYTES = 1024 + + def __init__(self, *, src_port, dst_port): + self.src_host = '127.0.0.1' + self.src_port = src_port + self.dst_host = '127.0.0.1' + self.dst_port = dst_port + self.connections = set() + + async def start(self): + return await asyncio.start_server( + self.handle_client, + host=self.src_host, + port=self.src_port, + ) + + async def disconnect(self): + while self.connections: + writer = self.connections.pop() + writer.close() + await writer.wait_closed() + + @staticmethod + async def _pipe( + reader: asyncio.StreamReader, writer: asyncio.StreamWriter + ): + try: + while not reader.at_eof(): + bytes_read = await reader.read(TcpProxy.MAX_BYTES) + writer.write(bytes_read) + finally: + writer.close() + + async def handle_client( + self, + client_reader: asyncio.StreamReader, + client_writer: asyncio.StreamWriter, + ): + server_reader, server_writer = await asyncio.open_connection( + host=self.dst_host, + port=self.dst_port + ) + + self.connections.add(server_writer) + self.connections.add(client_writer) + + await asyncio.wait([ + self._pipe(server_reader, client_writer), + self._pipe(client_reader, server_writer), + ]) diff --git a/tests/test_async_await.py b/tests/test_async_await.py index 0ad34e25..480119cf 100644 --- a/tests/test_async_await.py +++ b/tests/test_async_await.py @@ -55,7 +55,7 @@ async def test_pool_context_manager_timeout(pg_params, loop): async with aiopg.create_pool(**pg_params, minsize=1, maxsize=1) as pool: cursor_ctx = await pool.cursor() - with pytest.warns(ResourceWarning): + with pytest.warns(ResourceWarning, match='Invalid transaction status'): with cursor_ctx as cursor: hung_task = cursor.execute('SELECT pg_sleep(10000);') # start task diff --git a/tests/test_sa_engine.py b/tests/test_sa_engine.py index 1d44423b..b9ec6a66 100644 --- a/tests/test_sa_engine.py +++ b/tests/test_sa_engine.py @@ -1,5 +1,6 @@ import asyncio +import psycopg2 import pytest from psycopg2.extensions import parse_dsn from sqlalchemy import Column, Integer, MetaData, String, Table @@ -80,10 +81,10 @@ def test_not_context_manager(engine): async def test_release_transacted(engine): conn = await engine.acquire() tr = await conn.begin() - with pytest.raises(sa.InvalidRequestError): - engine.release(conn) + with pytest.warns(ResourceWarning, match='Invalid transaction status'): + await engine.release(conn) del tr - await conn.close() + assert conn.closed def test_timeout(engine): @@ -147,3 +148,24 @@ async def test_terminate_with_acquired_connections(make_engine): await engine.wait_closed() assert conn.closed + + +async def test_release_disconnected_connection( + tcp_proxy, unused_port, pg_params, make_engine +): + server_port = pg_params["port"] + proxy_port = unused_port() + + tcp_proxy = await tcp_proxy(proxy_port, server_port) + engine = await make_engine(port=proxy_port) + + with pytest.raises( + psycopg2.InterfaceError, match='connection already closed' + ): + with pytest.warns(ResourceWarning, match='Invalid transaction status'): + async with engine.acquire() as conn, conn.begin(): + await conn.execute('SELECT 1;') + await tcp_proxy.disconnect() + await conn.execute('SELECT 1;') + + assert engine.size == 0