diff --git a/aiopg/pool.py b/aiopg/pool.py index 41908a9d..61cbade7 100644 --- a/aiopg/pool.py +++ b/aiopg/pool.py @@ -167,8 +167,6 @@ async def _acquire(self): assert not conn.closed, conn assert conn not in self._used, (conn, self._used) self._used.add(conn) - if self._on_connect is not None: - await self._on_connect(conn) return conn else: await self._cond.wait() @@ -197,6 +195,8 @@ async def _fill_free_pool(self, override_min): enable_uuid=self._enable_uuid, echo=self._echo, **self._conn_kwargs) + if self._on_connect is not None: + await self._on_connect(conn) # raise exception if pool is closing self._free.append(conn) self._cond.notify() @@ -215,6 +215,8 @@ async def _fill_free_pool(self, override_min): enable_uuid=self._enable_uuid, echo=self._echo, **self._conn_kwargs) + if self._on_connect is not None: + await self._on_connect(conn) # raise exception if pool is closing self._free.append(conn) self._cond.notify() diff --git a/docs/core.rst b/docs/core.rst index f693e44f..d2a5625b 100644 --- a/docs/core.rst +++ b/docs/core.rst @@ -761,7 +761,7 @@ The basic usage is:: :param bool echo: executed log SQL queryes (disabled by default). - :param on_connect: a *callback coroutine* executed at once for every + :param on_connect: a *callback coroutine* executed once for every created connection. May be used for setting up connection level state like client encoding etc. diff --git a/tests/test_pool.py b/tests/test_pool.py index d06af613..4bc2a1fe 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -546,20 +546,28 @@ async def test_close_running_cursor(create_pool): await cur.execute('SELECT pg_sleep(10)') -async def test_pool_on_connect(create_pool): - called = False +@pytest.mark.parametrize('pool_minsize', [0, 1]) +async def test_pool_on_connect(create_pool, pool_minsize): + cb_called_times = 0 async def cb(connection): - nonlocal called + nonlocal cb_called_times async with connection.cursor() as cur: await cur.execute('SELECT 1') data = await cur.fetchall() assert [(1,)] == data - called = True + cb_called_times += 1 - pool = await create_pool(on_connect=cb) + pool = await create_pool( + minsize=pool_minsize, + maxsize=1, + on_connect=cb + ) + + with (await pool.cursor()) as cur: + await cur.execute('SELECT 1') with (await pool.cursor()) as cur: await cur.execute('SELECT 1') - assert called + assert cb_called_times == 1