Skip to content

Commit

Permalink
Fix on_connect multiple call on acquire (#552)
Browse files Browse the repository at this point in the history
* Make test_pool_on_connect check for multiple calls

* Move on_connect call to after new connection made

* Update docs to reflect behaviour of on_connect

* Test pool on_connect with different paths in _fill_free_pool
  • Loading branch information
aaliddell authored Dec 5, 2020
1 parent 047dee8 commit 7f8a846
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 9 deletions.
6 changes: 4 additions & 2 deletions aiopg/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,6 @@ async def _acquire(self):
assert not conn.closed, conn
assert conn not in self._used, (conn, self._used)
self._used.add(conn)
if self._on_connect is not None:
await self._on_connect(conn)
return conn
else:
await self._cond.wait()
Expand Down Expand Up @@ -197,6 +195,8 @@ async def _fill_free_pool(self, override_min):
enable_uuid=self._enable_uuid,
echo=self._echo,
**self._conn_kwargs)
if self._on_connect is not None:
await self._on_connect(conn)
# raise exception if pool is closing
self._free.append(conn)
self._cond.notify()
Expand All @@ -215,6 +215,8 @@ async def _fill_free_pool(self, override_min):
enable_uuid=self._enable_uuid,
echo=self._echo,
**self._conn_kwargs)
if self._on_connect is not None:
await self._on_connect(conn)
# raise exception if pool is closing
self._free.append(conn)
self._cond.notify()
Expand Down
2 changes: 1 addition & 1 deletion docs/core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -761,7 +761,7 @@ The basic usage is::

:param bool echo: executed log SQL queryes (disabled by default).

:param on_connect: a *callback coroutine* executed at once for every
:param on_connect: a *callback coroutine* executed once for every
created connection. May be used for setting up connection level
state like client encoding etc.

Expand Down
20 changes: 14 additions & 6 deletions tests/test_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,20 +546,28 @@ async def test_close_running_cursor(create_pool):
await cur.execute('SELECT pg_sleep(10)')


async def test_pool_on_connect(create_pool):
called = False
@pytest.mark.parametrize('pool_minsize', [0, 1])
async def test_pool_on_connect(create_pool, pool_minsize):
cb_called_times = 0

async def cb(connection):
nonlocal called
nonlocal cb_called_times
async with connection.cursor() as cur:
await cur.execute('SELECT 1')
data = await cur.fetchall()
assert [(1,)] == data
called = True
cb_called_times += 1

pool = await create_pool(on_connect=cb)
pool = await create_pool(
minsize=pool_minsize,
maxsize=1,
on_connect=cb
)

with (await pool.cursor()) as cur:
await cur.execute('SELECT 1')

with (await pool.cursor()) as cur:
await cur.execute('SELECT 1')

assert called
assert cb_called_times == 1

0 comments on commit 7f8a846

Please sign in to comment.