Skip to content

Commit

Permalink
[Trino] Add a flag to distinguish between the database and table/colu…
Browse files Browse the repository at this point in the history
…mn name
  • Loading branch information
agl29 committed Oct 17, 2024
1 parent 7f6aab6 commit 60f478a
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 9 deletions.
14 changes: 7 additions & 7 deletions desktop/libs/notebook/src/notebook/connectors/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,14 @@ def get_auth_password(self):
else DEFAULT_AUTH_PASSWORD.get()
)

def _format_identifier(self, identifier):
def _format_identifier(self, identifier, is_db=False):
# Remove any backticks
identifier = identifier.replace('`', '')

# Check if already formatted
if not (identifier.startswith('"') and identifier.endswith('"')):
# Check if it's a multi-part identifier (e.g., catalog.schema)
if '.' in identifier:
if '.' in identifier and is_db:
# Split and format each part separately
identifier = '"{}"'.format('"."'.join(identifier.split('.')))
else:
Expand All @@ -113,7 +113,7 @@ def create_session(self, lang=None, properties=None):
@query_error_handler
def execute(self, notebook, snippet):
database = snippet['database']
database = self._format_identifier(database)
database = self._format_identifier(database, is_db=True)
query_client = TrinoQuery(self.trino_request, 'USE ' + database)
query_client.execute()

Expand Down Expand Up @@ -258,7 +258,7 @@ def _get_select_query(self, database, table, column=None, operation=None, limit=
if operation == 'hello':
statement = "SELECT 'Hello World!'"
else:
database = self._format_identifier(database)
database = self._format_identifier(database, is_db=True)
table = self._format_identifier(table)
column = '%(column)s' % {'column': self._format_identifier(column)} if column else '*'
statement = textwrap.dedent('''\
Expand Down Expand Up @@ -312,7 +312,7 @@ def _show_catalogs(self):
return catalogs

def _show_tables(self, database):
database = self._format_identifier(database)
database = self._format_identifier(database, is_db=True)
query_client = TrinoQuery(self.trino_request, 'USE ' + database)
query_client.execute()
query_client = TrinoQuery(self.trino_request, 'SHOW TABLES')
Expand All @@ -327,7 +327,7 @@ def _show_tables(self, database):
]

def _get_columns(self, database, table):
database = self._format_identifier(database)
database = self._format_identifier(database, is_db=True)
query_client = TrinoQuery(self.trino_request, 'USE ' + database)
query_client.execute()
table = self._format_identifier(table)
Expand Down Expand Up @@ -356,7 +356,7 @@ def explain(self, notebook, snippet):
if statement:
try:
database = snippet['database']
database = self._format_identifier(database)
database = self._format_identifier(database, is_db=True)
TrinoQuery(self.trino_request, 'USE ' + database).execute()
result = TrinoQuery(self.trino_request, 'EXPLAIN ' + statement).execute()
explanation = result.rows
Expand Down
14 changes: 12 additions & 2 deletions desktop/libs/notebook/src/notebook/connectors/trino_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,23 @@ def setup_class(cls):
cls.trino_api = TrinoApi(cls.user, interpreter=cls.interpreter)

def test_format_identifier(self):
# db name test
test_cases = [
("my_db", '"my_db"'),
("my_db.table", '"my_db"."table"'),
("my_catalog.my_db", '"my_catalog"."my_db"'),
]

for database, expected_output in test_cases:
assert self.trino_api._format_identifier(database) == expected_output
assert self.trino_api._format_identifier(database, is_db=True) == expected_output

# table name test
test_cases = [
("io.airlift.discovery.store:name=dynamic,type=distributedstore", '"io.airlift.discovery.store:name=dynamic,type=distributedstore"'),
("table", '"table"'),
]

for table, expected_output in test_cases:
assert self.trino_api._format_identifier(table) == expected_output

def test_parse_api_url(self):
# Test parse_api_url method
Expand Down

0 comments on commit 60f478a

Please sign in to comment.