Skip to content

Commit

Permalink
feat(db): add async postgres db manager class
Browse files Browse the repository at this point in the history
  • Loading branch information
davidlougheed committed Nov 21, 2023
1 parent 6d427e2 commit 5bc3424
Show file tree
Hide file tree
Showing 7 changed files with 116 additions and 0 deletions.
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@ Common utilities and helpers for Bento platform services.

## Running Tests

For tests to complete successfully, the following external servers must be running:

* A Redis server at `localhost:6379`
* A Postgres server at `localhost:5432`

Then, tests and linting can be run with the following command:

```bash
python3 -m tox
```
Expand Down
Empty file added bento_lib/db/__init__.py
Empty file.
52 changes: 52 additions & 0 deletions bento_lib/db/pg_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import aiofiles
import asyncpg
import contextlib
from pathlib import Path
from typing import AsyncIterator


__all__ = [
"PgAsyncDatabase",
]


class PgAsyncDatabase:

def __init__(self, db_uri: str, schema_path: Path):
self._db_uri: str = db_uri
self._schema_path: Path = schema_path

self._pool: asyncpg.Pool | None = None

async def initialize(self, pool_size: int = 10):
conn: asyncpg.Connection

if not self._pool: # Initialize the connection pool if needed
self._pool = await asyncpg.create_pool(self._db_uri, min_size=pool_size, max_size=pool_size)

# Connect to the database and execute the schema script
async with aiofiles.open(self._schema_path, "r") as sf:
async with self.connect() as conn:
async with conn.transaction():
await conn.execute(await sf.read())

async def close(self):
if self._pool:
await self._pool.close()
self._pool = None

@contextlib.asynccontextmanager
async def connect(self, existing_conn: asyncpg.Connection | None = None) -> AsyncIterator[asyncpg.Connection]:
# TODO: raise raise DatabaseError("Pool is not available") when FastAPI has lifespan dependencies
# + manage pool lifespan in lifespan fn.

if self._pool is None:
await self.initialize() # initialize if this is the first time we're using the pool

if existing_conn is not None:
yield existing_conn
return

conn: asyncpg.Connection
async with self._pool.acquire() as conn:
yield conn
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ anyio==3.7.1
appdirs==1.4.4
asgiref==3.7.2
async-timeout==4.0.3
asyncpg==0.29.0
attrs==23.1.0
backports.entry-points-selectable==1.2.0
blinker==1.6.3
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
"Werkzeug>=2.2.3,<4",
],
extras_require={
"asyncpg": ["asyncpg>=0.29.0,<0.30.0"],
"flask": ["Flask>=2.2.5,<4"],
"django": ["Django>=4.2.1,<5", "djangorestframework>=3.14.0,<3.15"],
"fastapi": ["fastapi>=0.100,<0.105"],
Expand Down
1 change: 1 addition & 0 deletions tests/data/test.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE TABLE IF NOT EXISTS test_table (id SERIAL PRIMARY KEY);
54 changes: 54 additions & 0 deletions tests/test_db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import pathlib
import asyncpg
import pytest
import pytest_asyncio
from bento_lib.db.pg_async import PgAsyncDatabase
from typing import AsyncGenerator


TEST_SCHEMA = pathlib.Path(__file__).parent / "data" / "test.sql"


async def get_test_db() -> AsyncGenerator[PgAsyncDatabase, None]:
db_instance = PgAsyncDatabase("postgresql://localhost:5432/postgres", TEST_SCHEMA)
await db_instance.initialize(pool_size=1) # Small pool size for testing
yield db_instance


db_fixture = pytest_asyncio.fixture(get_test_db, name="pg_async_db")


@pytest_asyncio.fixture
async def db_cleanup(pg_async_db: PgAsyncDatabase):
yield
conn: asyncpg.Connection
async with pg_async_db.connect() as conn:
await conn.execute("DROP TABLE IF EXISTS test_table")
await pg_async_db.close()


# noinspection PyUnusedLocal
@pytest.mark.asyncio
async def test_pg_async_db_open_close(pg_async_db: PgAsyncDatabase, db_cleanup):
await pg_async_db.close()
assert pg_async_db._pool is None

# duplicate request: should be idempotent
await pg_async_db.close()
assert pg_async_db._pool is None

# should not be able to connect
conn: asyncpg.Connection
async with pg_async_db.connect() as conn:
assert pg_async_db._pool is not None # Connection auto-initialized
async with pg_async_db.connect(existing_conn=conn) as conn2:
assert conn == conn2 # Re-using existing connection should be possible

# try re-opening
await pg_async_db.initialize()
assert pg_async_db._pool is not None
old_pool = pg_async_db._pool

# duplicate request: should be idempotent
await pg_async_db.initialize()
assert pg_async_db._pool == old_pool # same instance

0 comments on commit 5bc3424

Please sign in to comment.