diff --git a/src/snowflake/sqlalchemy/snowdialect.py b/src/snowflake/sqlalchemy/snowdialect.py index a6ce7dca..04305a00 100644 --- a/src/snowflake/sqlalchemy/snowdialect.py +++ b/src/snowflake/sqlalchemy/snowdialect.py @@ -5,6 +5,7 @@ import operator from collections import defaultdict from functools import reduce +from typing import Any from urllib.parse import unquote_plus import sqlalchemy.types as sqltypes @@ -63,7 +64,11 @@ _CUSTOM_Float, _CUSTOM_Time, ) -from .util import _update_connection_application_name, parse_url_boolean +from .util import ( + _update_connection_application_name, + parse_url_boolean, + parse_url_integer, +) colspecs = { Date: _CUSTOM_Date, @@ -203,6 +208,26 @@ def import_dbapi(cls): return connector + @staticmethod + def parse_query_param_type(name: str, value: Any) -> Any: + """Cast param value if possible to type defined in connector-python.""" + if not (maybe_type_configuration := DEFAULT_CONFIGURATION.get(name)): + return value + + _, expected_type = maybe_type_configuration + if not isinstance(expected_type, tuple): + expected_type = (expected_type,) + + if isinstance(value, expected_type): + return value + + elif bool in expected_type: + return parse_url_boolean(value) + elif int in expected_type: + return parse_url_integer(value) + else: + return value + def create_connect_args(self, url: URL): opts = url.translate_connect_args(username="user") if "database" in opts: @@ -239,30 +264,7 @@ def create_connect_args(self, url: URL): # URL sets the query parameter values as strings, we need to cast to expected types when necessary for name, value in query.items(): - maybe_type_configuration = DEFAULT_CONFIGURATION.get(name) - if ( - not maybe_type_configuration - ): # if the parameter is not found in the type mapping, pass it through as a string - opts[name] = value - continue - - (_, expected_type) = maybe_type_configuration - if not isinstance(expected_type, tuple): - expected_type = (expected_type,) - - if isinstance( - value, expected_type - ): # if the expected type is str, pass it through as a string - opts[name] = value - - elif ( - bool in expected_type - ): # if the expected type is bool, parse it and pass as a boolean - opts[name] = parse_url_boolean(value) - else: - # TODO: other types like int are stil passed through as string - # https://github.com/snowflakedb/snowflake-sqlalchemy/issues/447 - opts[name] = value + opts[name] = self.parse_query_param_type(name, value) return ([], opts) @@ -281,7 +283,6 @@ def has_sequence(self, connection, sequence_name, schema=None, **kw): return self._has_object(connection, "SEQUENCE", sequence_name, schema) def _has_object(self, connection, object_type, object_name, schema=None): - full_name = self._denormalize_quote_join(schema, object_name) try: results = connection.execute( diff --git a/src/snowflake/sqlalchemy/util.py b/src/snowflake/sqlalchemy/util.py index 1738db3e..a1aefff9 100644 --- a/src/snowflake/sqlalchemy/util.py +++ b/src/snowflake/sqlalchemy/util.py @@ -125,6 +125,13 @@ def parse_url_boolean(value: str) -> bool: raise ValueError(f"Invalid boolean value detected: '{value}'") +def parse_url_integer(value: str) -> int: + try: + return int(value) + except ValueError as e: + raise ValueError(f"Invalid int value detected: '{value}") from e + + # handle Snowflake BCR bcr-1057 # the BCR impacts sqlalchemy.orm.context.ORMSelectCompileState and sqlalchemy.sql.selectable.SelectState # which used the 'sqlalchemy.util.preloaded.sql_util.find_left_clause_to_join_from' method that diff --git a/tests/test_core.py b/tests/test_core.py index a594b27a..179133c8 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -34,7 +34,7 @@ inspect, text, ) -from sqlalchemy.exc import DBAPIError, NoSuchTableError +from sqlalchemy.exc import DBAPIError, NoSuchTableError, OperationalError from sqlalchemy.sql import and_, not_, or_, select import snowflake.connector.errors @@ -1059,28 +1059,15 @@ def harass_inspector(): assert outcome -@pytest.mark.timeout(15) -def test_region(): - engine = create_engine( - URL( - user="testuser", - password="testpassword", - account="testaccount", - region="eu-central-1", - login_timeout=5, - ) - ) - try: - engine.connect() - pytest.fail("should not run") - except Exception as ex: - assert ex.orig.errno == 250001 - assert "Failed to connect to DB" in ex.orig.msg - assert "testaccount.eu-central-1.snowflakecomputing.com" in ex.orig.msg - - -@pytest.mark.timeout(15) -def test_azure(): +@pytest.mark.timeout(10) +@pytest.mark.parametrize( + "region", + ( + pytest.param("eu-central-1", id="region"), + pytest.param("east-us-2.azure", id="azure"), + ), +) +def test_connection_timeout_error(region): engine = create_engine( URL( user="testuser", @@ -1090,13 +1077,13 @@ def test_azure(): login_timeout=5, ) ) - try: + + with pytest.raises(OperationalError) as excinfo: engine.connect() - pytest.fail("should not run") - except Exception as ex: - assert ex.orig.errno == 250001 - assert "Failed to connect to DB" in ex.orig.msg - assert "testaccount.east-us-2.azure.snowflakecomputing.com" in ex.orig.msg + + assert excinfo.value.orig.errno == 250001 + assert "Could not connect to Snowflake backend" in excinfo.value.orig.msg + assert region not in excinfo.value.orig.msg def test_load_dialect():