From 60f478a7aa2c14a9d0c0f4e918a4ef3711fe4b5c Mon Sep 17 00:00:00 2001 From: agl29 Date: Wed, 16 Oct 2024 13:54:07 +0530 Subject: [PATCH] [Trino] Add a flag to distinguish between the database and table/column name --- .../libs/notebook/src/notebook/connectors/trino.py | 14 +++++++------- .../src/notebook/connectors/trino_tests.py | 14 ++++++++++++-- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/desktop/libs/notebook/src/notebook/connectors/trino.py b/desktop/libs/notebook/src/notebook/connectors/trino.py index 4cf626a5b56..921b1ce0b7c 100644 --- a/desktop/libs/notebook/src/notebook/connectors/trino.py +++ b/desktop/libs/notebook/src/notebook/connectors/trino.py @@ -85,14 +85,14 @@ def get_auth_password(self): else DEFAULT_AUTH_PASSWORD.get() ) - def _format_identifier(self, identifier): + def _format_identifier(self, identifier, is_db=False): # Remove any backticks identifier = identifier.replace('`', '') # Check if already formatted if not (identifier.startswith('"') and identifier.endswith('"')): # Check if it's a multi-part identifier (e.g., catalog.schema) - if '.' in identifier: + if '.' in identifier and is_db: # Split and format each part separately identifier = '"{}"'.format('"."'.join(identifier.split('.'))) else: @@ -113,7 +113,7 @@ def create_session(self, lang=None, properties=None): @query_error_handler def execute(self, notebook, snippet): database = snippet['database'] - database = self._format_identifier(database) + database = self._format_identifier(database, is_db=True) query_client = TrinoQuery(self.trino_request, 'USE ' + database) query_client.execute() @@ -258,7 +258,7 @@ def _get_select_query(self, database, table, column=None, operation=None, limit= if operation == 'hello': statement = "SELECT 'Hello World!'" else: - database = self._format_identifier(database) + database = self._format_identifier(database, is_db=True) table = self._format_identifier(table) column = '%(column)s' % {'column': self._format_identifier(column)} if column else '*' statement = textwrap.dedent('''\ @@ -312,7 +312,7 @@ def _show_catalogs(self): return catalogs def _show_tables(self, database): - database = self._format_identifier(database) + database = self._format_identifier(database, is_db=True) query_client = TrinoQuery(self.trino_request, 'USE ' + database) query_client.execute() query_client = TrinoQuery(self.trino_request, 'SHOW TABLES') @@ -327,7 +327,7 @@ def _show_tables(self, database): ] def _get_columns(self, database, table): - database = self._format_identifier(database) + database = self._format_identifier(database, is_db=True) query_client = TrinoQuery(self.trino_request, 'USE ' + database) query_client.execute() table = self._format_identifier(table) @@ -356,7 +356,7 @@ def explain(self, notebook, snippet): if statement: try: database = snippet['database'] - database = self._format_identifier(database) + database = self._format_identifier(database, is_db=True) TrinoQuery(self.trino_request, 'USE ' + database).execute() result = TrinoQuery(self.trino_request, 'EXPLAIN ' + statement).execute() explanation = result.rows diff --git a/desktop/libs/notebook/src/notebook/connectors/trino_tests.py b/desktop/libs/notebook/src/notebook/connectors/trino_tests.py index fa08c5ba916..10aa7f7e4d5 100644 --- a/desktop/libs/notebook/src/notebook/connectors/trino_tests.py +++ b/desktop/libs/notebook/src/notebook/connectors/trino_tests.py @@ -42,13 +42,23 @@ def setup_class(cls): cls.trino_api = TrinoApi(cls.user, interpreter=cls.interpreter) def test_format_identifier(self): + # db name test test_cases = [ ("my_db", '"my_db"'), - ("my_db.table", '"my_db"."table"'), + ("my_catalog.my_db", '"my_catalog"."my_db"'), ] for database, expected_output in test_cases: - assert self.trino_api._format_identifier(database) == expected_output + assert self.trino_api._format_identifier(database, is_db=True) == expected_output + + # table name test + test_cases = [ + ("io.airlift.discovery.store:name=dynamic,type=distributedstore", '"io.airlift.discovery.store:name=dynamic,type=distributedstore"'), + ("table", '"table"'), + ] + + for table, expected_output in test_cases: + assert self.trino_api._format_identifier(table) == expected_output def test_parse_api_url(self): # Test parse_api_url method