diff --git a/README.md b/README.md index d22da6c5..6da190fe 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,13 @@ Common utilities and helpers for Bento platform services. ## Running Tests +For tests to complete successfully, the following external servers must be running: + +* A Redis server at `localhost:6379` +* A Postgres server at `localhost:5432` + +Then, tests and linting can be run with the following command: + ```bash python3 -m tox ``` diff --git a/bento_lib/db/__init__.py b/bento_lib/db/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/bento_lib/db/pg_async.py b/bento_lib/db/pg_async.py new file mode 100644 index 00000000..4d50f2b8 --- /dev/null +++ b/bento_lib/db/pg_async.py @@ -0,0 +1,52 @@ +import aiofiles +import asyncpg +import contextlib +from pathlib import Path +from typing import AsyncIterator + + +__all__ = [ + "PgAsyncDatabase", +] + + +class PgAsyncDatabase: + + def __init__(self, db_uri: str, schema_path: Path): + self._db_uri: str = db_uri + self._schema_path: Path = schema_path + + self._pool: asyncpg.Pool | None = None + + async def initialize(self, pool_size: int = 10): + conn: asyncpg.Connection + + if not self._pool: # Initialize the connection pool if needed + self._pool = await asyncpg.create_pool(self._db_uri, min_size=pool_size, max_size=pool_size) + + # Connect to the database and execute the schema script + async with aiofiles.open(self._schema_path, "r") as sf: + async with self.connect() as conn: + async with conn.transaction(): + await conn.execute(await sf.read()) + + async def close(self): + if self._pool: + await self._pool.close() + self._pool = None + + @contextlib.asynccontextmanager + async def connect(self, existing_conn: asyncpg.Connection | None = None) -> AsyncIterator[asyncpg.Connection]: + # TODO: raise raise DatabaseError("Pool is not available") when FastAPI has lifespan dependencies + # + manage pool lifespan in lifespan fn. + + if self._pool is None: + await self.initialize() # initialize if this is the first time we're using the pool + + if existing_conn is not None: + yield existing_conn + return + + conn: asyncpg.Connection + async with self._pool.acquire() as conn: + yield conn diff --git a/requirements.txt b/requirements.txt index 4649269b..459f1bb7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,6 +7,7 @@ anyio==3.7.1 appdirs==1.4.4 asgiref==3.7.2 async-timeout==4.0.3 +asyncpg==0.29.0 attrs==23.1.0 backports.entry-points-selectable==1.2.0 blinker==1.6.3 diff --git a/setup.py b/setup.py index 118387df..8259a036 100644 --- a/setup.py +++ b/setup.py @@ -25,6 +25,7 @@ "Werkzeug>=2.2.3,<4", ], extras_require={ + "asyncpg": ["asyncpg>=0.29.0,<0.30.0"], "flask": ["Flask>=2.2.5,<4"], "django": ["Django>=4.2.1,<5", "djangorestframework>=3.14.0,<3.15"], "fastapi": ["fastapi>=0.100,<0.105"], diff --git a/tests/data/test.sql b/tests/data/test.sql new file mode 100644 index 00000000..5f0c6f7a --- /dev/null +++ b/tests/data/test.sql @@ -0,0 +1 @@ +CREATE TABLE IF NOT EXISTS test_table (id SERIAL PRIMARY KEY); diff --git a/tests/test_db.py b/tests/test_db.py new file mode 100644 index 00000000..76878f5c --- /dev/null +++ b/tests/test_db.py @@ -0,0 +1,54 @@ +import pathlib +import asyncpg +import pytest +import pytest_asyncio +from bento_lib.db.pg_async import PgAsyncDatabase +from typing import AsyncGenerator + + +TEST_SCHEMA = pathlib.Path(__file__).parent / "data" / "test.sql" + + +async def get_test_db() -> AsyncGenerator[PgAsyncDatabase, None]: + db_instance = PgAsyncDatabase("postgresql://localhost:5432/postgres", TEST_SCHEMA) + await db_instance.initialize(pool_size=1) # Small pool size for testing + yield db_instance + + +db_fixture = pytest_asyncio.fixture(get_test_db, name="pg_async_db") + + +@pytest_asyncio.fixture +async def db_cleanup(pg_async_db: PgAsyncDatabase): + yield + conn: asyncpg.Connection + async with pg_async_db.connect() as conn: + await conn.execute("DROP TABLE IF EXISTS test_table") + await pg_async_db.close() + + +# noinspection PyUnusedLocal +@pytest.mark.asyncio +async def test_pg_async_db_open_close(pg_async_db: PgAsyncDatabase, db_cleanup): + await pg_async_db.close() + assert pg_async_db._pool is None + + # duplicate request: should be idempotent + await pg_async_db.close() + assert pg_async_db._pool is None + + # should not be able to connect + conn: asyncpg.Connection + async with pg_async_db.connect() as conn: + assert pg_async_db._pool is not None # Connection auto-initialized + async with pg_async_db.connect(existing_conn=conn) as conn2: + assert conn == conn2 # Re-using existing connection should be possible + + # try re-opening + await pg_async_db.initialize() + assert pg_async_db._pool is not None + old_pool = pg_async_db._pool + + # duplicate request: should be idempotent + await pg_async_db.initialize() + assert pg_async_db._pool == old_pool # same instance