diff --git a/DESCRIPTION.md b/DESCRIPTION.md index da226b07..47133034 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -12,6 +12,8 @@ Source code is also available at: - v1.5.1(Unreleased) - Fixed a compatibility issue with Snowflake Behavioral Change 1057 on outer lateral join, for more details check https://docs.snowflake.com/en/release-notes/bcr-bundles/2023_04/bcr-1057. + - Fixed credentials with `externalbrowser` authentication not caching due to incorrect parsing of boolean query parameters. + - This fixes other boolean parameter passing to driver as well. - v1.5.0(Aug 23, 2023) diff --git a/src/snowflake/sqlalchemy/snowdialect.py b/src/snowflake/sqlalchemy/snowdialect.py index 350027f4..4fefa07f 100644 --- a/src/snowflake/sqlalchemy/snowdialect.py +++ b/src/snowflake/sqlalchemy/snowdialect.py @@ -11,7 +11,7 @@ from sqlalchemy import event as sa_vnt from sqlalchemy import exc as sa_exc from sqlalchemy import util as sa_util -from sqlalchemy.engine import default, reflection +from sqlalchemy.engine import URL, default, reflection from sqlalchemy.schema import Table from sqlalchemy.sql import text from sqlalchemy.sql.elements import quoted_name @@ -38,6 +38,7 @@ ) from snowflake.connector import errors as sf_errors +from snowflake.connector.connection import DEFAULT_CONFIGURATION from snowflake.connector.constants import UTF8 from .base import ( @@ -62,7 +63,7 @@ _CUSTOM_Float, _CUSTOM_Time, ) -from .util import _update_connection_application_name +from .util import _update_connection_application_name, parse_url_boolean colspecs = { Date: _CUSTOM_Date, @@ -109,7 +110,6 @@ "GEOMETRY": GEOMETRY, } - _ENABLE_SQLALCHEMY_AS_APPLICATION_NAME = True @@ -199,7 +199,7 @@ def dbapi(cls): return connector - def create_connect_args(self, url): + def create_connect_args(self, url: URL): opts = url.translate_connect_args(username="user") if "database" in opts: name_spaces = [unquote_plus(e) for e in opts["database"].split("/")] @@ -226,10 +226,40 @@ def create_connect_args(self, url): opts["host"] = opts["host"] + ".snowflakecomputing.com" opts["port"] = "443" opts["autocommit"] = False # autocommit is disabled by default - opts.update(url.query) + + query = dict(**url.query) # make mutable + cache_column_metadata = query.pop("cache_column_metadata", None) self._cache_column_metadata = ( - opts.get("cache_column_metadata", "false").lower() == "true" + parse_url_boolean(cache_column_metadata) if cache_column_metadata else False ) + + # URL sets the query parameter values as strings, we need to cast to expected types when necessary + for name, value in query.items(): + maybe_type_configuration = DEFAULT_CONFIGURATION.get(name) + if ( + not maybe_type_configuration + ): # if the parameter is not found in the type mapping, pass it through as a string + opts[name] = value + continue + + (_, expected_type) = maybe_type_configuration + if not isinstance(expected_type, tuple): + expected_type = (expected_type,) + + if isinstance( + value, expected_type + ): # if the expected type is str, pass it through as a string + opts[name] = value + + elif ( + bool in expected_type + ): # if the expected type is bool, parse it and pass as a boolean + opts[name] = parse_url_boolean(value) + else: + # TODO: other types like int are stil passed through as string + # https://github.com/snowflakedb/snowflake-sqlalchemy/issues/447 + opts[name] = value + return ([], opts) def has_table(self, connection, table_name, schema=None): diff --git a/src/snowflake/sqlalchemy/util.py b/src/snowflake/sqlalchemy/util.py index 631ceaee..54044349 100644 --- a/src/snowflake/sqlalchemy/util.py +++ b/src/snowflake/sqlalchemy/util.py @@ -115,6 +115,15 @@ def _update_connection_application_name(**conn_kwargs: Any) -> Any: return conn_kwargs +def parse_url_boolean(value: str) -> bool: + if value == "True": + return True + elif value == "False": + return False + else: + raise ValueError(f"Invalid boolean value detected: '{value}'") + + # handle Snowflake BCR bcr-1057 # the BCR impacts sqlalchemy.orm.context.ORMSelectCompileState and sqlalchemy.sql.selectable.SelectState # which used the 'sqlalchemy.util.preloaded.sql_util.find_left_clause_to_join_from' method that diff --git a/tests/test_core.py b/tests/test_core.py index 8206e43d..29c55ae9 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -173,6 +173,27 @@ def test_connect_args(): engine.dispose() +def test_boolean_query_argument_parsing(): + engine = create_engine( + URL( + user=CONNECTION_PARAMETERS["user"], + password=CONNECTION_PARAMETERS["password"], + account=CONNECTION_PARAMETERS["account"], + host=CONNECTION_PARAMETERS["host"], + port=CONNECTION_PARAMETERS["port"], + protocol=CONNECTION_PARAMETERS["protocol"], + validate_default_parameters=True, + ) + ) + try: + verify_engine_connection(engine) + connection = engine.raw_connection() + assert connection.validate_default_parameters is True + finally: + connection.close() + engine.dispose() + + def test_create_dialect(): """ Tests getting only dialect object through create_engine