-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(db): add async postgres db manager class
- Loading branch information
1 parent
6d427e2
commit 5bc3424
Showing
7 changed files
with
116 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import aiofiles | ||
import asyncpg | ||
import contextlib | ||
from pathlib import Path | ||
from typing import AsyncIterator | ||
|
||
|
||
__all__ = [ | ||
"PgAsyncDatabase", | ||
] | ||
|
||
|
||
class PgAsyncDatabase: | ||
|
||
def __init__(self, db_uri: str, schema_path: Path): | ||
self._db_uri: str = db_uri | ||
self._schema_path: Path = schema_path | ||
|
||
self._pool: asyncpg.Pool | None = None | ||
|
||
async def initialize(self, pool_size: int = 10): | ||
conn: asyncpg.Connection | ||
|
||
if not self._pool: # Initialize the connection pool if needed | ||
self._pool = await asyncpg.create_pool(self._db_uri, min_size=pool_size, max_size=pool_size) | ||
|
||
# Connect to the database and execute the schema script | ||
async with aiofiles.open(self._schema_path, "r") as sf: | ||
async with self.connect() as conn: | ||
async with conn.transaction(): | ||
await conn.execute(await sf.read()) | ||
|
||
async def close(self): | ||
if self._pool: | ||
await self._pool.close() | ||
self._pool = None | ||
|
||
@contextlib.asynccontextmanager | ||
async def connect(self, existing_conn: asyncpg.Connection | None = None) -> AsyncIterator[asyncpg.Connection]: | ||
# TODO: raise raise DatabaseError("Pool is not available") when FastAPI has lifespan dependencies | ||
# + manage pool lifespan in lifespan fn. | ||
|
||
if self._pool is None: | ||
await self.initialize() # initialize if this is the first time we're using the pool | ||
|
||
if existing_conn is not None: | ||
yield existing_conn | ||
return | ||
|
||
conn: asyncpg.Connection | ||
async with self._pool.acquire() as conn: | ||
yield conn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
CREATE TABLE IF NOT EXISTS test_table (id SERIAL PRIMARY KEY); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import pathlib | ||
import asyncpg | ||
import pytest | ||
import pytest_asyncio | ||
from bento_lib.db.pg_async import PgAsyncDatabase | ||
from typing import AsyncGenerator | ||
|
||
|
||
TEST_SCHEMA = pathlib.Path(__file__).parent / "data" / "test.sql" | ||
|
||
|
||
async def get_test_db() -> AsyncGenerator[PgAsyncDatabase, None]: | ||
db_instance = PgAsyncDatabase("postgresql://localhost:5432/postgres", TEST_SCHEMA) | ||
await db_instance.initialize(pool_size=1) # Small pool size for testing | ||
yield db_instance | ||
|
||
|
||
db_fixture = pytest_asyncio.fixture(get_test_db, name="pg_async_db") | ||
|
||
|
||
@pytest_asyncio.fixture | ||
async def db_cleanup(pg_async_db: PgAsyncDatabase): | ||
yield | ||
conn: asyncpg.Connection | ||
async with pg_async_db.connect() as conn: | ||
await conn.execute("DROP TABLE IF EXISTS test_table") | ||
await pg_async_db.close() | ||
|
||
|
||
# noinspection PyUnusedLocal | ||
@pytest.mark.asyncio | ||
async def test_pg_async_db_open_close(pg_async_db: PgAsyncDatabase, db_cleanup): | ||
await pg_async_db.close() | ||
assert pg_async_db._pool is None | ||
|
||
# duplicate request: should be idempotent | ||
await pg_async_db.close() | ||
assert pg_async_db._pool is None | ||
|
||
# should not be able to connect | ||
conn: asyncpg.Connection | ||
async with pg_async_db.connect() as conn: | ||
assert pg_async_db._pool is not None # Connection auto-initialized | ||
async with pg_async_db.connect(existing_conn=conn) as conn2: | ||
assert conn == conn2 # Re-using existing connection should be possible | ||
|
||
# try re-opening | ||
await pg_async_db.initialize() | ||
assert pg_async_db._pool is not None | ||
old_pool = pg_async_db._pool | ||
|
||
# duplicate request: should be idempotent | ||
await pg_async_db.initialize() | ||
assert pg_async_db._pool == old_pool # same instance |