Skip to content

Commit

Permalink
Improve index reflection (#556)
Browse files Browse the repository at this point in the history
Improve index reflection
  • Loading branch information
sfc-gh-jvasquezrojas authored Dec 12, 2024
1 parent 695c0a9 commit 716683f
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 106 deletions.
7 changes: 5 additions & 2 deletions DESCRIPTION.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@ Source code is also available at:

# Release Notes

- (Unreleased)
- Fix quoting of `_` as column name
- Fix index columns not being reflected
- Fix index reflection cache not working

- v1.7.1(December 02, 2024)
- Add support for partition by to copy into <location>
- Fix BOOLEAN type not found in snowdialect

- v1.7.0(November 21, 2024)

- Fixed quoting of `_` as column name
- Add support for dynamic tables and required options
- Add support for hybrid tables
- Fixed SAWarning when registering functions with existing name in default namespace
Expand Down
16 changes: 16 additions & 0 deletions src/snowflake/sqlalchemy/parser/custom_type_parser.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
from typing import List

import sqlalchemy.types as sqltypes
from sqlalchemy.sql.type_api import TypeEngine
Expand Down Expand Up @@ -107,6 +108,21 @@ def extract_parameters(text: str) -> list:
return output_parameters


def parse_index_columns(columns: str) -> List[str]:
    """
    Parses a string with a list of columns for an index.

    :param columns: A string with a comma-separated list of columns for an
        index, which may be wrapped in square brackets.
    :return: A list of columns as strings, in the order they appear. Empty
        entries are dropped, so `"[]"`, `""` and `None` all produce `[]`.
    :example:
        For input `"[A, B, C]"`, the output is `['A', 'B', 'C']`.
    """
    if not columns:
        return []
    # Trim outer whitespace first so the brackets are actually at the ends
    # of the string when they get stripped.
    candidates = (
        column.strip() for column in columns.strip().strip("[]").split(",")
    )
    # "".split(",") yields [""]; filter so an empty list of columns does not
    # turn into a single column with an empty name.
    return [column for column in candidates if column]


def parse_type(type_text: str) -> TypeEngine:
"""
Parses a type definition string and returns the corresponding SQLAlchemy type.
Expand Down
184 changes: 84 additions & 100 deletions src/snowflake/sqlalchemy/snowdialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import re
from collections import defaultdict
from functools import reduce
from typing import Any
from typing import Any, Collection, Optional
from urllib.parse import unquote_plus

import sqlalchemy.types as sqltypes
Expand Down Expand Up @@ -41,7 +41,7 @@
)
from .parser.custom_type_parser import * # noqa
from .parser.custom_type_parser import _CUSTOM_DECIMAL # noqa
from .parser.custom_type_parser import ischema_names, parse_type
from .parser.custom_type_parser import ischema_names, parse_index_columns, parse_type
from .sql.custom_schema.custom_table_prefix import CustomTablePrefix
from .util import (
_update_connection_application_name,
Expand Down Expand Up @@ -674,27 +674,43 @@ def get_columns(self, connection, table_name, schema=None, **kw):
raise sa_exc.NoSuchTableError()
return schema_columns[normalized_table_name]

def get_prefixes_from_data(self, name_to_index_map, row, **kw):
    """
    Collects the names of the table prefixes flagged on a result row.

    For every member of ``CustomTablePrefix`` the column ``is_<member name in
    lower case>`` is looked up; the member counts as present when that column
    exists in ``name_to_index_map`` and the row holds ``"Y"`` there.

    :param name_to_index_map: Mapping from column name to position in ``row``.
    :param row: A single result row, indexable by position.
    :return: List of matching prefix names, in ``CustomTablePrefix`` order.
    """

    def is_flagged(prefix):
        column = f"is_{prefix.name.lower()}"
        return (
            column in name_to_index_map
            and row[name_to_index_map[column]] == "Y"
        )

    return [prefix.name for prefix in CustomTablePrefix if is_flagged(prefix)]

@reflection.cache
def _get_schema_tables_info(self, connection, schema=None, **kw):
    """
    Retrieves information about all tables in the specified schema.

    Runs ``SHOW TABLES IN SCHEMA`` for ``schema`` (falling back to
    ``self.default_schema_name`` when it is not given).

    :param connection: Connection used to execute the statement.
    :param schema: Schema to inspect; optional.
    :return: Dict keyed by normalized table name whose values are dicts of
        the form ``{"prefixes": [...]}`` as produced by
        ``get_prefixes_from_data``.
    """
    schema = schema or self.default_schema_name
    quoted_schema = self._denormalize_quote_join(schema)
    result = connection.execute(
        text(
            f"SHOW /* sqlalchemy:get_schema_tables_info */ TABLES IN SCHEMA {quoted_schema}"
        )
    )

    columns = self._map_name_to_idx(result)
    return {
        self.normalize_name(str(row[columns["name"]])): {
            "prefixes": self.get_prefixes_from_data(columns, row)
        }
        for row in result.cursor.fetchall()
    }

def get_table_names(self, connection, schema=None, **kw):
    """
    Gets all table names.

    :param connection: Connection used to query the schema.
    :param schema: Schema to inspect; optional.
    :param kw: Only ``info_cache`` is forwarded to the cached lookup.
    :return: A list with the names of the tables found in the schema.
    """
    tables_info = self._get_schema_tables_info(
        connection, schema, info_cache=kw.get("info_cache", None)
    )
    # Materialize a real list instead of handing out a dict view: callers
    # may index or concatenate the result, which a keys view does not support.
    return list(tables_info)

@reflection.cache
Expand Down Expand Up @@ -748,17 +764,12 @@ def get_view_definition(self, connection, view_name, schema=None, **kw):

def get_temp_table_names(self, connection, schema=None, **kw):
schema = schema or self.default_schema_name
if schema:
cursor = connection.execute(
text(
f"SHOW /* sqlalchemy:get_temp_table_names */ TABLES \
IN {self._denormalize_quote_join(schema)}"
)
)
else:
cursor = connection.execute(
text("SHOW /* sqlalchemy:get_temp_table_names */ TABLES")
cursor = connection.execute(
text(
f"SHOW /* sqlalchemy:get_temp_table_names */ TABLES \
IN SCHEMA {self._denormalize_quote_join(schema)}"
)
)

ret = []
n2i = self.__class__._map_name_to_idx(cursor)
Expand Down Expand Up @@ -839,62 +850,79 @@ def get_table_comment(self, connection, table_name, schema=None, **kw):
)
}

def get_multi_indexes(
def get_table_names_with_prefix(
    self,
    connection,
    *,
    schema,
    prefix,
    **kw,
):
    """
    Gets the names of the tables in a schema that carry a given prefix.

    :param connection: Connection used to query the schema.
    :param schema: Schema to inspect.
    :param prefix: Prefix name to look for in each table's ``"prefixes"``
        list (for example ``CustomTablePrefix.HYBRID.name``).
    :param kw: Forwarded unchanged to ``_get_schema_tables_info``.
    :return: List of table names whose prefixes include ``prefix``.
    """
    tables_data = self._get_schema_tables_info(connection, schema, **kw)
    return [
        table_name
        for table_name, table_info in tables_data.items()
        if prefix in table_info["prefixes"]
    ]

def get_multi_indexes(
    self,
    connection,
    *,
    schema: Optional[str] = None,
    filter_names: Optional[Collection[str]] = None,
    **kw,
):
    """
    Gets the indexes definition

    Only tables reported with the ``HYBRID`` prefix are considered; when the
    schema has none, no ``SHOW INDEXES`` statement is issued.

    :param connection: Connection used to execute the statements.
    :param schema: Schema to inspect; falls back to
        ``self.default_schema_name`` when not given.
    :param filter_names: Table names to keep. ``None`` disables the filter.
    :param kw: Only ``info_cache`` is forwarded to the table lookup.
    :return: List of ``((schema, table_name), [index, ...])`` tuples, where
        each index is a dict with the keys ``name``, ``unique``,
        ``column_names``, ``include_columns`` and ``dialect_options``.
    """
    schema = schema or self.default_schema_name
    hybrid_table_names = set(
        self.get_table_names_with_prefix(
            connection,
            schema=schema,
            prefix=CustomTablePrefix.HYBRID.name,
            info_cache=kw.get("info_cache", None),
        )
    )
    if not hybrid_table_names:
        return []

    result = connection.execute(
        text(
            f"SHOW /* sqlalchemy:get_multi_indexes */ INDEXES IN SCHEMA {self._denormalize_quote_join(schema)}"
        )
    )

    n2i = self._map_name_to_idx(result)
    indexes = {}

    for row in result.cursor.fetchall():
        table_name = self.normalize_name(str(row[n2i["table"]]))
        if (
            # Presumably the implicit primary-key index; it is skipped here.
            row[n2i["name"]] == f'SYS_INDEX_{row[n2i["table"]]}_PRIMARY'
            # ``filter_names`` defaults to None, which must not be used
            # with ``in``; None means "no filtering".
            or (filter_names is not None and table_name not in filter_names)
            or table_name not in hybrid_table_names
        ):
            continue
        index = {
            "name": row[n2i["name"]],
            "unique": row[n2i["is_unique"]] == "Y",
            "column_names": [
                self.normalize_name(column)
                for column in parse_index_columns(row[n2i["columns"]])
            ],
            "include_columns": [
                self.normalize_name(column)
                for column in parse_index_columns(row[n2i["included_columns"]])
            ],
            "dialect_options": {},
        }

        # list.append returns None, so its result must never be assigned
        # back; setdefault creates the list once and appends in place.
        indexes.setdefault((schema, table_name), []).append(index)

    return list(indexes.items())

Expand All @@ -906,50 +934,6 @@ def _value_or_default(self, data, table, schema):
else:
return []

def get_prefixes_from_data(self, n2i, row, **kw):
prefixes_found = []
for valid_prefix in CustomTablePrefix:
key = f"is_{valid_prefix.name.lower()}"
if key in n2i and row[n2i[key]] == "Y":
prefixes_found.append(valid_prefix.name)
return prefixes_found

@reflection.cache
def get_multi_prefixes(
self, connection, schema, table_name=None, filter_prefix=None, **kw
):
"""
Gets all table prefixes
"""
schema = schema or self.default_schema_name
filter = f"LIKE '{table_name}'" if table_name else ""
if schema:
result = connection.execute(
text(
f"SHOW /* sqlalchemy:get_multi_prefixes */ {filter} TABLES IN SCHEMA {schema}"
)
)
else:
result = connection.execute(
text(
f"SHOW /* sqlalchemy:get_multi_prefixes */ {filter} TABLES LIKE '{table_name}'"
)
)

n2i = self.__class__._map_name_to_idx(result)
tables_prefixes = {}
for row in result.cursor.fetchall():
table = self.normalize_name(str(row[n2i["name"]]))
table_prefixes = self.get_prefixes_from_data(n2i, row)
if filter_prefix and filter_prefix not in table_prefixes:
continue
if (schema, table) in tables_prefixes:
tables_prefixes[(schema, table)].append(table_prefixes)
else:
tables_prefixes[(schema, table)] = table_prefixes

return tables_prefixes

@reflection.cache
def get_indexes(self, connection, tablename, schema, **kw):
"""
Expand Down
27 changes: 27 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#
from __future__ import annotations

import logging.handlers
import os
import sys
import time
Expand Down Expand Up @@ -194,6 +195,32 @@ def engine_testaccount(request):
yield engine


@pytest.fixture()
def assert_text_in_buf():
    """
    Fixture that captures ``sqlalchemy.engine`` log records and yields a
    checker for how many times an exact message was logged.

    The yielded callable ``go(expected, occurrences=1)`` asserts that at
    least one record was captured and that exactly ``occurrences`` of the
    captured messages equal ``expected``; it then flushes the handler so the
    next call starts from an empty buffer. The handler is detached from the
    logger after the test.
    """
    handler = logging.handlers.BufferingHandler(100)
    engine_logger = logging.getLogger("sqlalchemy.engine")
    engine_logger.addHandler(handler)

    def go(expected, occurrences=1):
        assert handler.buffer
        messages = [record.getMessage() for record in handler.buffer]

        found = messages.count(expected)
        assert occurrences == found, (
            f"Expected {occurrences} of {expected}, got {found} "
            f"occurrences in {messages}."
        )
        handler.flush()

    yield go
    engine_logger.removeHandler(handler)


@pytest.fixture()
def engine_testaccount_with_numpy(request):
url = url_factory(numpy=True)
Expand Down
Loading

0 comments on commit 716683f

Please sign in to comment.