Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use asyncpg dsn #704

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 2 additions & 7 deletions src/gino/dialects/asyncpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,14 +245,9 @@ def __init__(self, *pargs, **kwargs):
self.baked_queries = {}

args.update(
loop=self._loop,
host=self._url.host,
port=self._url.port,
user=self._url.username,
database=self._url.database,
password=self._url.password,
connection_class=Connection,
connection_class=Connection, dsn=str(self._url), loop=self._loop,
)

if self._prebake and self._bakery:
self._init_hook = args.pop("init", None)
args["init"] = self._bake
Expand Down
20 changes: 11 additions & 9 deletions src/gino/strategies.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import asyncio
from copy import copy

from sqlalchemy.engine import url
from sqlalchemy import util
from sqlalchemy.engine.url import make_url
from sqlalchemy.engine.strategies import EngineStrategy

from .engine import GinoEngine
from .dialects.asyncpg import AsyncpgDialect


class GinoStrategy(EngineStrategy):
Expand All @@ -14,23 +15,24 @@ class GinoStrategy(EngineStrategy):
This strategy is initialized automatically as :mod:`gino` is imported.

If :func:`sqlalchemy.create_engine` uses ``strategy="gino"``, it will return a
:class:`~collections.abc.Coroutine`, and treat URL prefix ``postgresql://`` or
``postgres://`` as ``postgresql+asyncpg://``.
:class:`~collections.abc.Coroutine`.
"""

name = "gino"
engine_cls = GinoEngine

async def create(self, name_or_url, loop=None, **kwargs):
engine_cls = self.engine_cls
u = url.make_url(name_or_url)
url = make_url(name_or_url)
if loop is None:
loop = asyncio.get_event_loop()
if u.drivername in {"postgresql", "postgres"}:
u = copy(u)
u.drivername = "postgresql+asyncpg"

dialect_cls = u.get_dialect()
# The postgresql dialect is already taken by the PGDialect_psycopg2
# we need to force ourone.
if url.drivername in ("postgresql", "postgres"):
dialect_cls = AsyncpgDialect
else:
dialect_cls = url.get_dialect()

pop_kwarg = kwargs.pop

Expand All @@ -52,7 +54,7 @@ async def create(self, name_or_url, loop=None, **kwargs):

dialect = dialect_cls(**dialect_args)
pool_class = kwargs.pop("pool_class", None)
pool = await dialect.init_pool(u, loop, pool_class=pool_class)
pool = await dialect.init_pool(url, loop, pool_class=pool_class)

engine_args = dict(loop=loop)
for k in util.get_cls_kwargs(engine_cls):
Expand Down
31 changes: 25 additions & 6 deletions tests/test_bind.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from gino.exceptions import UninitializedError
from sqlalchemy.engine.url import make_url

from .models import db, PG_URL, User
from .models import db, DB_ARGS, PG_URL, User

pytestmark = pytest.mark.asyncio

Expand Down Expand Up @@ -55,9 +55,28 @@ async def test_db_api(bind, random_name):
assert params[0] == 3


async def test_bind_url():
url = make_url(PG_URL)
assert url.drivername == "postgresql"
await db.set_bind(PG_URL)
assert url.drivername == "postgresql"
@pytest.mark.parametrize(
"dsn, driver_name",
(
(
"postgresql://{user}:{password}@{host}:{port}/{database}".format(**DB_ARGS),
"postgresql",
),
(
"postgres://{user}:{password}@{host}:{port}/{database}".format(**DB_ARGS),
"postgres",
),
(
"postgres://{user}:{password}@/{database}?host={host}&port={port}".format(
**DB_ARGS
),
"postgres",
),
),
)
async def test_bind_url(dsn, driver_name):
url = make_url(dsn)
assert url.drivername == driver_name
await db.set_bind(dsn)
assert url.drivername == driver_name
await db.pop_bind().close()